diff --git a/.agents/skills/strands-review/SKILL.md b/.agents/skills/strands-review/SKILL.md new file mode 100644 index 0000000000..0defa55b4c --- /dev/null +++ b/.agents/skills/strands-review/SKILL.md @@ -0,0 +1,281 @@ +--- +name: strands-review +description: Local preview of the strands-agents/devtools `/strands review` agent. Body is the upstream Task Reviewer SOP verbatim — do not paraphrase. Use when the user types `/strands-review`, asks for a "strands review" of a PR, or wants to anticipate what the remote `/strands review` GitHub Action will flag. Findings are close but not identical to the remote agent. Strongly prefer running this skill in a fresh-context subagent rather than inline — the SOP is long and reviewer judgment is more reliable when it isn't entangled with the parent conversation's prior context. +source: https://github.com/strands-agents/devtools/blob/main/strands-command/agent-sops/task-reviewer.sop.md +--- + + + +# Task Reviewer SOP + +## Role + +You are a Task Reviewer, and your goal is to review code changes in a pull request and provide constructive feedback to improve code quality, maintainability, and adherence to project standards. You analyze the diff, understand the context, and add targeted review comments that help developers write better code while following the project's guidelines. + +## Steps + +### 1. Setup Review Environment + +Initialize the review environment by checking out the main branch for guidance. + +**Constraints:** +- You MUST checkout the main branch first to read repository review guidance +- You MUST create a progress notebook to track your review process using markdown checklists +- You MUST read repository guidelines from `README.md`, `CONTRIBUTING.md`, and `AGENTS.md` (if present) +- You MUST read API bar raising guidelines from https://github.com/strands-agents/docs/blob/main/team/API_BAR_RAISING.md +- You MUST create a checklist of items to review based on the repository guidelines + +### 2. Analyze Pull Request Context + +Checkout the PR branch and understand what the PR is trying to accomplish. + +**Constraints:** +- You MUST checkout the PR branch to review the actual changes +- You MUST read the pull request description and understand the purpose of the changes +- You MUST note the PR number and branch name in your notebook +- You MUST identify the type of changes (feature, bugfix, refactor, etc.) +- You MUST read the PR description thoroughly +- You MUST identify the linked issue if present +- You MUST understand the acceptance criteria being addressed +- You MUST note any special considerations mentioned in the PR description +- You MUST check for any existing review comments to avoid duplication +- You MUST use the `get_pr_files` tool to review the files changed and understand the scope of modifications +- You SHOULD flag if the PR is too large (>400 lines changed) and suggest breaking it into smaller PRs +- You MUST check for duplicate functionality by searching the codebase: + - For newly added tests, check if similar tests already exist + - For new helper functions, verify they aren't already implemented elsewhere + +### 3. Code Analysis Phase + +Perform a comprehensive analysis of the code changes. + +#### 3.1 Structural Review + +Analyze the overall structure and architecture of the changes. + +**Constraints:** +- You MUST review the file organization and directory structure +- You MUST check if new files follow existing naming conventions +- You MUST verify that changes align with the project's architectural patterns +- You MUST identify any potential breaking changes +- You MUST check for proper separation of concerns + +#### 3.2 API Bar Raising Review + +If the PR introduces or modifies public APIs, evaluate the API design from a customer perspective. + +**Constraints:** +- You MUST check if the PR has `needs-api-review` or `completed-api-review` labels +- You MUST verify the PR includes API documentation in the description: + - Expected use cases for the new feature + - Example code snippets demonstrating usage + - Complete API signatures with default parameter values + - Module exports (what's exported from each module) +- You MUST evaluate the API against SDK tenets (https://github.com/strands-agents/docs/blob/main/team/TENETS.md) and decision records (https://github.com/strands-agents/docs/blob/main/team/DECISIONS.md) +- You MUST verify the API addresses documented use cases +- You MUST check if default parameters/behavior represent the most common usage +- You MUST assess the level of abstraction and extensibility: + - What is customizable and what is not? + - Is it the proper level of abstraction? +- You MUST identify use cases that are not addressed and question why +- You MUST flag if the PR requires API review but lacks the `needs-api-review` label for: + - New public classes or abstractions customers will use + - New primitives or frequently-used functionality + - Changes to existing public API contracts +- You MAY suggest the change scope requires designated API reviewer or team consensus if substantial + +#### 3.3 Code Quality Review + +Examine the code for quality, readability, and maintainability issues. + +**Constraints:** +- You MUST check for language-specific best practices as defined in repository guidelines +- You MUST verify code is readable with clear variable/function names and logical structure +- You MUST check that code is maintainable with modular design and loose coupling +- You MUST check for code complexity and suggest simplifications +- You MUST identify unclear or confusing code patterns +- You MUST verify proper error handling +- You MUST check for potential performance issues +- You MUST verify design decisions are documented (why certain patterns were chosen, alternatives considered, tradeoffs made) + +#### 3.4 Testing Review + +Analyze the test coverage and quality of tests. + +**Constraints:** +- You MUST verify that new functionality has corresponding tests +- You MUST check that tests follow the patterns defined in repository documentation +- You MUST ensure tests are in the correct directories as specified in guidelines +- You MUST check for proper test organization and naming +- You MUST identify missing edge cases or error scenarios +- You MUST verify integration tests are included when appropriate +- You MUST flag tests that assert on individual fields when the full object or shape can be asserted in a single equality check, since per-field assertions silently miss unexpected or regressed fields +- You MAY accept per-field assertions only when a field is non-deterministic or irrelevant to the behavior under test, and the test isolates that field rather than splitting the whole assertion + +### 4. Generate Review Comments + +Create specific, actionable review comments for identified issues. + +**Constraints:** +- You MUST focus on the most impactful improvements first +- You MUST provide specific suggestions rather than vague feedback +- You MUST be concise in your feedback +- You MUST avoid nitpicking on minor style issues (nits) - focus on substantive problems: + - Nits include: comment wording, code organization preferences, bracket/semicolon position, filename conventions + - Substantive issues include: bugs, security vulnerabilities, performance problems, maintainability concerns +- You MUST assume positive intent from the code author +- You MUST categorize feedback as: + - **Critical**: Must be fixed (security, breaking changes, major bugs) + - **Important**: Should be fixed (quality, maintainability, standards) + - **Suggestion**: Nice to have (optimizations, style preferences) +- You MUST be constructive and educational in your feedback +- You MUST prioritize feedback that helps the developer learn and improve +- You MAY skip this step if you have no feedback to provide + +#### 4.1 Comment Structure + +Format review comments to be clear and actionable. + +**Constraints:** +- You MUST be concise - avoid verbose explanations +- You MUST provide specific suggestions +- You MAY reference documentation or standards when applicable +- You SHOULD use this format: + ``` + **Issue**: [Brief description] + **Suggestion**: [Specific recommendation] + ``` + +### 5. Post Review Comments + +Add the review comments to the pull request. + +**Constraints:** +- You MUST use the `add_pr_comment` tool for inline comments on specific lines +- You MUST use the `add_pr_comment` tool with no line number for file-level comments +- You MUST use the `reply_to_review_comment` tool to reply to existing inline comments +- You MUST group related comments when possible +- You MUST avoid overwhelming the author with too many minor comments +- You MUST prioritize the most important feedback +- You MUST be respectful and professional in all comments +- You SHOULD limit to 10-15 comments per review to avoid overwhelming the author +- You MUST focus on improvements and suggestions only +- You MUST NOT add inline comments praising good coding practices + +### 6. Summary Review Comment + +Provide a concise overall summary of the review. + +**Constraints:** +- You MUST create a pull request review using GitHub's review feature +- You MUST provide an overall assessment (Approve, Request Changes, Comment) +- You MUST keep the summary concise, informative, and easy to read +- You MUST NOT repeat information already covered in inline comments +- You MUST focus on high-level themes and patterns, not individual issues +- You MUST use collapsible `
` sections if the summary contains multiple categories or is longer than 5 lines +- You MAY include a brief positive note at the end (1 sentence maximum) +- You SHOULD use this format: + ``` + **Assessment**: [Approve/Request Changes/Comment] + + [Brief high-level summary of review themes - 1-2 sentences] + +
+ Review Categories + + - **[Category]**: [High-level pattern or theme, not specific issues] + - **[Category]**: [High-level pattern or theme, not specific issues] + +
+ + [Optional: Brief positive note - 1 sentence max] + ``` + +## Review Focus Areas + +### Code Quality Priorities + +Focus on substantive issues that impact code quality, not stylistic preferences: + +1. **Functionality**: Does the code work as intended? Are edge cases and error conditions handled? +2. **Readability**: Is the code clear with descriptive names and logical structure? +3. **Maintainability**: Is the code modular, loosely coupled, and easy to modify in the future? +4. **Security**: Are there vulnerabilities or data exposure risks? +5. **Performance**: Are there bottlenecks or inefficient algorithms? +6. **Testing**: Is there comprehensive test coverage including edge cases? +7. **Language Best Practices**: Does it follow language-specific best practices as defined in repository guidelines? +8. **Design Documentation**: Are design decisions, alternatives, and tradeoffs documented? +9. **Dependency Bounds**: Do new or changed dependencies have a supported upper bound to prevent breakage from major version releases? + +## Best Practices + +### Review Efficiency +- Focus on the most impactful issues first +- Provide specific, actionable feedback +- Be concise and avoid verbose explanations +- Reference project standards and documentation when applicable +- Be educational and constructive + +### Communication +- Be respectful and professional +- Assume positive intent from the code author +- Acknowledge good practices +- Explain the reasoning behind feedback +- Provide learning opportunities +- Encourage the developer +- Focus on ideas for improving the system, not criticisms of the author + +### Quality Gates +- Ensure critical issues are marked as blocking +- Verify tests meet repository requirements +- Check language-specific compliance as defined in guidelines +- Validate documentation completeness + +## Troubleshooting + +### Large Pull Requests +If the PR is very large: +- Focus on architectural and design issues first +- Prioritize critical bugs and security issues +- Suggest breaking the PR into smaller pieces if appropriate +- Provide high-level feedback on structure and approach + +### Complex Changes +For complex technical changes: +- Take time to understand the full context +- Ask clarifying questions if needed +- Focus on maintainability and future extensibility +- Verify that the solution aligns with project guidelines + +### Disagreements +If you disagree with the approach: +- Explain your reasoning clearly +- Reference project guidelines and standards +- Suggest alternative approaches +- Be open to discussion and learning diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..94f480de94 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto eol=lf \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index b3898b7f78..19617fcc9e 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -19,6 +19,16 @@ body: required: true - label: "I have searched [./issues](./issues?q=) and there are no duplicates of my issue" required: true + - type: dropdown + id: sdk-language + attributes: + label: SDK Language + description: Which Strands SDK are you using? + options: + - Python + - TypeScript + validations: + required: true - type: input id: strands-version attributes: @@ -28,11 +38,11 @@ body: validations: required: true - type: input - id: python-version + id: language-version attributes: - label: Python Version - description: Which version of Python are you using? - placeholder: e.g., 3.10.5 + label: Language Runtime Version + description: Which version of Python or Node.js are you using? + placeholder: e.g., Python 3.10.5 or Node.js 20.17.0 validations: required: true - type: input @@ -50,8 +60,10 @@ body: description: How did you install Strands? options: - pip + - npm + - yarn + - pnpm - git clone - - binary - other validations: required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 1c67d4e86b..4c4b04753f 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,6 +1,7 @@ blank_issues_enabled: false contact_links: - name: Strands Agents SDK Support + # Only one repo has Discussions enabled; point users there. url: https://github.com/strands-agents/sdk-python/discussions about: Please ask and answer questions here - name: Strands Agents SDK Documentation diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 01406b0314..f33ff1d14f 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -11,6 +11,28 @@ updates: dev-dependencies: patterns: - "pytest" + - package-ecosystem: "npm" + directory: "/" + schedule: + interval: "daily" + open-pull-requests-limit: 100 + commit-message: + prefix: "ci(typescript)" + cooldown: + default-days: 5 + semver-major-days: 30 + semver-minor-days: 7 + semver-patch-days: 3 + groups: + development-dependencies: + dependency-type: "development" + applies-to: version-updates + production-minor: + dependency-type: "production" + applies-to: version-updates + update-types: + - "minor" + - "patch" - package-ecosystem: "npm" directory: "/site" schedule: @@ -18,6 +40,12 @@ updates: open-pull-requests-limit: 100 commit-message: prefix: "ci(docs)" + - package-ecosystem: "pip" + directory: "/strands-py-wasm" + schedule: + interval: "daily" + commit-message: + prefix: "ci(python)" - package-ecosystem: "github-actions" directory: "/" schedule: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dcd162540a..27ef0ef36e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,6 +19,7 @@ jobs: contents: read outputs: python: ${{ steps.filter.outputs.python }} + typescript: ${{ steps.filter.outputs.typescript }} docs: ${{ steps.filter.outputs.docs }} steps: - uses: actions/checkout@v6 @@ -30,6 +31,16 @@ jobs: - 'strands-py/**' - '.github/workflows/python-*' - '.github/workflows/ci.yml' + typescript: + - 'strands-ts/**' + - 'strands-wasm/**' + - 'strands-py-wasm/**' + - 'strandly/**' + - 'wit/**' + - 'package.json' + - 'package-lock.json' + - '.github/workflows/typescript-*' + - '.github/workflows/ci.yml' docs: - 'site/**' - '.github/workflows/docs-*' @@ -47,6 +58,14 @@ jobs: secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + typescript: + name: TypeScript + needs: detect-changes + if: needs.detect-changes.outputs.typescript == 'true' + uses: ./.github/workflows/typescript-pr-and-push.yml + permissions: + contents: read + docs: name: Docs needs: detect-changes @@ -58,10 +77,10 @@ jobs: ci-gate: name: CI Gate if: always() - needs: [detect-changes, python, docs] + needs: [detect-changes, python, typescript, docs] runs-on: ubuntu-latest steps: - uses: re-actors/alls-green@release/v1 with: - allowed-skips: python, docs + allowed-skips: python, typescript, docs jobs: ${{ toJSON(needs) }} diff --git a/.github/workflows/typescript-integration-test.yml b/.github/workflows/typescript-integration-test.yml new file mode 100644 index 0000000000..41f64ebb48 --- /dev/null +++ b/.github/workflows/typescript-integration-test.yml @@ -0,0 +1,81 @@ +name: "TypeScript: Secure Integration Test" + +on: + pull_request_target: + branches: [main] + paths: + - 'strands-ts/**' + - 'strands-wasm/**' + - 'strands-py-wasm/**' + - 'strandly/**' + - 'wit/**' + - 'package.json' + - 'package-lock.json' + - '.github/workflows/typescript-*' + merge_group: + types: [checks_requested] +jobs: + authorization-check: + name: Check access + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.auth.outputs.approval-env }} + steps: + - name: Check Authorization + id: auth + uses: strands-agents/devtools/authorization-check@main + with: + skip-check: ${{ github.event_name == 'merge_group' }} + username: ${{ github.event.pull_request.user.login || 'invalid' }} + allowed-roles: 'triage,write,maintain,admin' + + run-integration-tests: + name: Run integration tests + runs-on: ubuntu-latest + needs: authorization-check + environment: ${{ needs.authorization-check.outputs.approval-env }} + permissions: + id-token: write + pull-requests: read + contents: read + steps: + - name: Configure Credentials + uses: aws-actions/configure-aws-credentials@v6 + with: + role-to-assume: ${{ secrets.AWS_ROLE_ARN }} + aws-region: us-east-1 + mask-aws-account-id: true + + - name: Checkout head commit + uses: actions/checkout@v6 + with: + ref: ${{ github.event.pull_request.head.sha }} # Pull the commit from the forked repo + persist-credentials: false # Don't persist credentials for subsequent actions + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: 22 + package-manager-cache: false + + - name: Install dependencies + run: | + npm ci + npm run test:browser:install + + - name: Build the package + run: npm run build + + - name: Run integration tests + run: npm run test:integ:all + + - name: Upload test artifacts + if: always() + uses: actions/upload-artifact@v7 + with: + name: test-artifacts-integ + path: ./strands-ts/test/.artifacts/ + retention-days: 4 + include-hidden-files: true # needed because the path has a directory starting with a '.' + if-no-files-found: ignore diff --git a/.github/workflows/typescript-npm-publish-on-release.yml b/.github/workflows/typescript-npm-publish-on-release.yml new file mode 100644 index 0000000000..7df77a820c --- /dev/null +++ b/.github/workflows/typescript-npm-publish-on-release.yml @@ -0,0 +1,91 @@ +name: "TypeScript: Publish NPM Package" + +on: + release: + types: + - published + +jobs: + call-ts-check: + if: startsWith(github.event.release.tag_name, 'typescript/v') + uses: ./.github/workflows/typescript-ts-check.yml + permissions: + contents: read + with: + ref: ${{ github.event.release.target_commitish }} + + call-ts-test: + if: startsWith(github.event.release.tag_name, 'typescript/v') + uses: ./.github/workflows/typescript-ts-test.yml + permissions: + contents: read + with: + ref: ${{ github.event.release.target_commitish }} + + publish: + name: Build and publish to NPM + if: startsWith(github.event.release.tag_name, 'typescript/v') + needs: + - call-ts-check + - call-ts-test + runs-on: ubuntu-latest + + environment: + name: npm + url: https://www.npmjs.com/package/@strands-agents/sdk + permissions: + id-token: write + contents: read + + steps: + - uses: actions/checkout@v6 + with: + persist-credentials: false + + - name: Set up Node.js + uses: actions/setup-node@v6 + with: + node-version: '20' + registry-url: 'https://registry.npmjs.org' + package-manager-cache: false + + - name: Update npm to latest + run: npm install -g npm@latest + + - name: Extract version from tag + id: version + run: | + VERSION=${GITHUB_REF#refs/tags/typescript/v} + echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "Extracted version: $VERSION" + + - name: Validate version + run: | + if [[ ${{ steps.version.outputs.version }} =~ ^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$ ]]; then + echo "Valid version format" + exit 0 + else + echo "Invalid version format" + exit 1 + fi + + - name: Update package.json version + working-directory: strands-ts + run: | + npm version ${{ steps.version.outputs.version }} --no-git-tag-version + + - name: Install dependencies + run: npm ci + + - name: Build + run: npm run build + + - name: Store the distribution packages + uses: actions/upload-artifact@v7 + with: + name: npm-package-distributions + path: ./strands-ts + + - name: Publish to NPM + working-directory: strands-ts + run: npm publish --access public --tag latest diff --git a/.github/workflows/typescript-pr-and-push.yml b/.github/workflows/typescript-pr-and-push.yml new file mode 100644 index 0000000000..764954aae8 --- /dev/null +++ b/.github/workflows/typescript-pr-and-push.yml @@ -0,0 +1,45 @@ +name: "TypeScript: Pull Request and Push" + +on: + workflow_call: + push: + branches: [ main ] + paths: + - 'strands-ts/**' + - 'strands-wasm/**' + - 'strands-py-wasm/**' + - 'strandly/**' + - 'wit/**' + - 'package.json' + - 'package-lock.json' + - '.github/workflows/typescript-*' + workflow_dispatch: + +jobs: + call-security-audit: + uses: ./.github/workflows/typescript-security-audit.yml + permissions: + contents: read + with: + ref: ${{ github.event.pull_request.head.sha || github.sha }} + + call-ts-check: + uses: ./.github/workflows/typescript-ts-check.yml + permissions: + contents: read + with: + ref: ${{ github.event.pull_request.head.sha || github.sha }} + + call-ts-test: + uses: ./.github/workflows/typescript-ts-test.yml + permissions: + contents: read + with: + ref: ${{ github.event.pull_request.head.sha || github.sha }} + + call-test-package-pack: + uses: ./.github/workflows/typescript-test-package-pack.yml + permissions: + contents: read + with: + ref: ${{ github.event.pull_request.head.sha || github.sha }} diff --git a/.github/workflows/typescript-security-audit.yml b/.github/workflows/typescript-security-audit.yml new file mode 100644 index 0000000000..7d31cb97e8 --- /dev/null +++ b/.github/workflows/typescript-security-audit.yml @@ -0,0 +1,33 @@ +name: "TypeScript: Security Audit" + +on: + workflow_call: + inputs: + ref: + required: true + type: string + +jobs: + security-audit: + name: NPM Security Audit + permissions: + contents: read + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + ref: ${{ inputs.ref }} + persist-credentials: false + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: 22 + + - name: Install dependencies + run: npm ci + + - name: Run security audit + run: npm audit --audit-level=high diff --git a/.github/workflows/typescript-test-package-pack.yml b/.github/workflows/typescript-test-package-pack.yml new file mode 100644 index 0000000000..0adc5f7b88 --- /dev/null +++ b/.github/workflows/typescript-test-package-pack.yml @@ -0,0 +1,73 @@ +# End-to-end package install smoke test. +# +# At RC.0 the SDK published a main-entry re-export from an optional peer +# dependency. A real user's `npm install @strands-agents/sdk` succeeded, then +# crashed at module load because the optional peer was not installed. +# +# The existing `test:package` doesn't catch this: it resolves the SDK via +# `file:../../..` and the monorepo's root `node_modules` hoists every optional +# peer (they're devDependencies of the repo), silently satisfying the bad +# re-export. This workflow reproduces a real user's install: `npm pack` the +# SDK, install the tarball in a tempdir OUTSIDE the monorepo tree so nothing +# hoists, then type-check and run a consumer script that touches the public +# surface. A missing optional peer fails at module load the same way it would +# for an end user. +name: "TypeScript: Test Package Pack" + +on: + workflow_call: + inputs: + ref: + required: true + type: string + +jobs: + test-package-pack: + name: Pack Install + permissions: + contents: read + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + ref: ${{ inputs.ref }} + persist-credentials: false + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: 20 + package-manager-cache: false + + - name: Install dependencies + run: npm ci + + - name: Build + run: npm run build + + - name: Pack, install in tmpdir, type-check, and run + # The tempdir MUST live outside the monorepo — otherwise npm hoists + # devDependencies from the repo root into resolution and silently + # satisfies missing optional peers, defeating the test. + run: | + set -euo pipefail + TARBALL=$(cd strands-ts && npm pack --silent) + TARBALL_PATH="$GITHUB_WORKSPACE/strands-ts/$TARBALL" + echo "Packed: $TARBALL_PATH" + + WORK=$(mktemp -d) + trap 'rm -rf "$WORK"' EXIT + echo "Workspace: $WORK" + cp strands-ts/test/packages/npm-pack/package.json \ + strands-ts/test/packages/npm-pack/tsconfig.json \ + strands-ts/test/packages/npm-pack/verify.ts \ + "$WORK/" + cd "$WORK" + + npm install --ignore-scripts --no-audit --no-fund + npm install --ignore-scripts --no-audit --no-fund "$TARBALL_PATH" + + npx tsc --noEmit + npx tsx verify.ts diff --git a/.github/workflows/typescript-ts-check.yml b/.github/workflows/typescript-ts-check.yml new file mode 100644 index 0000000000..27accf3da2 --- /dev/null +++ b/.github/workflows/typescript-ts-check.yml @@ -0,0 +1,46 @@ +name: "TypeScript: Code Quality" + +on: + workflow_call: + inputs: + ref: + required: true + type: string + +jobs: + code-quality: + name: Code Quality + permissions: + contents: read + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + ref: ${{ inputs.ref }} + persist-credentials: false + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: 22 + package-manager-cache: false + + - name: Install dependencies + run: npm ci + + - name: Build + run: npm run build + + - name: Run linting + run: npm run lint + + - name: Check code formatting + run: npm run format:check + + - name: Run type checking + run: npm run type-check + + - name: Verify browser bundle + run: npm run check:browser-bundle diff --git a/.github/workflows/typescript-ts-test.yml b/.github/workflows/typescript-ts-test.yml new file mode 100644 index 0000000000..c0e71cf946 --- /dev/null +++ b/.github/workflows/typescript-ts-test.yml @@ -0,0 +1,58 @@ +name: "TypeScript: Test" + +on: + workflow_call: + inputs: + ref: + required: true + type: string + +jobs: + test: + name: Test (Node ${{ matrix.node-version }} on ${{ matrix.os }}) + permissions: + contents: read + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + node-version: [20, 22, 24] + os: [ubuntu-latest, windows-latest, macos-latest] + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + ref: ${{ inputs.ref }} + persist-credentials: false + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: ${{ matrix.node-version }} + package-manager-cache: false + + - name: Install dependencies + run: npm ci + + - name: Install Playwright browsers + run: npm run test:browser:install + + - name: Run unit tests + run: npm run test:all:coverage + + - name: Upload test artifacts + if: always() + uses: actions/upload-artifact@v7 + with: + name: test-artifacts-${{ matrix.node-version }}-${{ matrix.os }} + path: ./strands-ts/test/.artifacts/ + include-hidden-files: true # needed because the path has a directory starting with a '.' + retention-days: 4 + if-no-files-found: ignore + + - name: Build package + run: npm run build + + - name: Test packaging + run: npm run test:package diff --git a/.gitignore b/.gitignore index 596d54cd3c..cfd543ac0d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,84 @@ -.DS_Store +# Dependencies +node_modules/ +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# Test lock files +**/test/packages/**/package-lock.json + +# Build outputs +dist/ +build/ +*.tsbuildinfo +target/ + +# Python +__pycache__/ __pycache__* -.env -.venv -*.bak -.vscode +.pycache/ +.pytest_cache/ +.ruff_cache/ +.venv/ +*.egg-info/ +*.pyd +*.so +*.dylib +*.pdb + +# Generated type bindings (committed only for strands-py-wasm/src/strands/_generated.py) +strands-ts/generated/ +strands-wasm/generated/ + +# Coverage reports +coverage/ + +# IDE files +.claude/* +!.claude/skills +!.claude/references .kiro/* !.kiro/skills !.kiro/references +.vscode/ +.idea/ +*.swp +*.swo +.gradle/ + +# OS files +.DS_Store +Thumbs.db + +# Logs +*.log + +# Environment files +.env +.env.local +.env.development.local +.env.test.local +.env.production.local + +# Github workflow artifacts +.artifact + +# Test artifacts +**/test/.artifacts + +# Misc +*.bak +*.backup +**/mutants.out*/ +bin/ + +# LLM CLAUDE.md .claude/settings.local.json + +# dev +.vitest* + +# Files copied into strands-ts/ during prepack (originals live at repo root) +strands-ts/README.md +strands-ts/LICENSE diff --git a/.husky/pre-commit b/.husky/pre-commit new file mode 100755 index 0000000000..cb5235ff0b --- /dev/null +++ b/.husky/pre-commit @@ -0,0 +1,26 @@ +echo "Running pre-commit checks..." + +# Build (required for integ type-check: workspace symlink resolves to dist/) +echo "Building..." +npm run build || { echo "Build failed. Commit aborted."; exit 1; } + +# Run tests +echo "Running tests..." +npm run test:coverage || { echo "Tests failed. Commit aborted."; exit 1; } +# WASM tests disabled until componentize-js stream support lands and +# entry.ts re-binds to the per-domain WIT layout. CI's ts-test job runs +# the TS SDK suite; WASM tests are tracked under phase-1 bridge tasks. + +# Run linting +echo "Running linting..." +npm run lint || { echo "Linting failed. Commit aborted."; exit 1; } + +# Check formatting +echo "Checking code formatting..." +npm run format:check || { echo "Formatting check failed. Run 'npm run format' to fix. Commit aborted."; exit 1; } + +# Type checking +echo "Running type checks..." +npm run type-check || { echo "Type checking failed. Commit aborted."; exit 1; } + +echo "All pre-commit checks passed!" diff --git a/.node-version b/.node-version new file mode 100644 index 0000000000..829e9737e4 --- /dev/null +++ b/.node-version @@ -0,0 +1 @@ +20.19.0 \ No newline at end of file diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000000..623325469e --- /dev/null +++ b/.prettierrc @@ -0,0 +1,7 @@ +{ + "semi": false, + "singleQuote": true, + "printWidth": 120, + "tabWidth": 2, + "trailingComma": "es5" +} diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..1c16778b78 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,1075 @@ +# Agent Development Guide - Strands TypeScript SDK + +This document provides guidance specifically for AI agents working on the Strands TypeScript SDK codebase. For human contributor guidelines, see [CONTRIBUTING.md](CONTRIBUTING.md). + +## Purpose and Scope + +**AGENTS.md** contains agent-specific repository information including: + +- Directory structure with summaries of what is included in each directory +- Development workflow instructions for agents to follow when developing features +- Coding patterns and testing patterns to follow when writing code +- Style guidelines, organizational patterns, and best practices + +**For human contributors**: See [CONTRIBUTING.md](CONTRIBUTING.md) for setup, testing, and contribution guidelines. + +## Directory Structure + +The repo is an npm workspace monorepo. The root `package.json` delegates all build/test/lint commands to the `strands-ts` workspace package. + +``` +sdk-typescript/ +├── strands-ts/ # SDK workspace package +│ ├── src/ # All production code +│ │ ├── __fixtures__/ # Shared test fixtures (mocks, helpers) +│ │ ├── __tests__/ # Unit tests for root-level source files +│ │ │ +│ │ ├── a2a/ # Agent-to-agent protocol +│ │ │ ├── __tests__/ +│ │ │ ├── a2a-agent.ts # A2A agent client +│ │ │ ├── adapters.ts # Strands/A2A type converters +│ │ │ ├── events.ts # A2A streaming events +│ │ │ ├── executor.ts # A2A executor +│ │ │ ├── express-server.ts # Express-based A2A server +│ │ │ ├── logging.ts # A2A-specific logging +│ │ │ ├── server.ts # A2A server base +│ │ │ └── index.ts +│ │ │ +│ │ ├── agent/ # Agent loop and streaming +│ │ │ ├── __tests__/ +│ │ │ ├── agent.ts # Core agent implementation +│ │ │ ├── agent-as-tool.ts # Wrap agent as a tool +│ │ │ ├── printer.ts # Agent output printing +│ │ │ ├── snapshot.ts # Agent state snapshots +│ │ │ └── tool-caller.ts # Direct tool calling via agent.tool accessor +│ │ │ +│ │ ├── conversation-manager/ # Conversation history strategies +│ │ │ ├── __tests__/ +│ │ │ ├── conversation-manager.ts +│ │ │ ├── null-conversation-manager.ts +│ │ │ ├── sliding-window-conversation-manager.ts +│ │ │ ├── summarizing-conversation-manager.ts +│ │ │ └── index.ts +│ │ │ +│ │ ├── hooks/ # Hooks system for extensibility +│ │ │ ├── __tests__/ +│ │ │ ├── events.ts +│ │ │ ├── registry.ts +│ │ │ ├── types.ts +│ │ │ └── index.ts +│ │ │ +│ │ ├── logging/ # Structured logging +│ │ │ ├── __tests__/ +│ │ │ ├── logger.ts +│ │ │ ├── warn-once.ts # Dedupe warnings by message content +│ │ │ ├── types.ts +│ │ │ └── index.ts +│ │ │ +│ │ ├── models/ # Model provider implementations +│ │ │ ├── __tests__/ +│ │ │ ├── google/ # Google Gemini provider +│ │ │ │ ├── adapters.ts +│ │ │ │ ├── errors.ts +│ │ │ │ ├── model.ts +│ │ │ │ ├── types.ts +│ │ │ │ └── index.ts +│ │ │ ├── openai/ # OpenAI provider (Chat Completions + Responses API) +│ │ │ │ ├── __tests__/ # Unit tests (chat.test.ts, responses.test.ts) +│ │ │ │ ├── chat-adapter.ts +│ │ │ │ ├── responses-adapter.ts +│ │ │ │ ├── formatting.ts +│ │ │ │ ├── errors.ts +│ │ │ │ ├── model.ts +│ │ │ │ ├── types.ts +│ │ │ │ └── index.ts +│ │ │ ├── anthropic.ts # Anthropic Claude +│ │ │ ├── bedrock.ts # AWS Bedrock +│ │ │ ├── vercel.ts # Vercel AI SDK +│ │ │ ├── defaults.ts # Centralized model defaults + warning messages +│ │ │ ├── model.ts # Base model interface +│ │ │ └── streaming.ts # Streaming event types +│ │ │ +│ │ ├── multiagent/ # Multi-agent orchestration +│ │ │ ├── __tests__/ +│ │ │ ├── graph.ts # Graph orchestrator (DAG) +│ │ │ ├── swarm.ts # Swarm orchestrator (handoff) +│ │ │ ├── multiagent.ts # Base multi-agent class +│ │ │ ├── nodes.ts # Node types +│ │ │ ├── state.ts # State management +│ │ │ ├── events.ts # Streaming events +│ │ │ ├── edge.ts # Edge definitions +│ │ │ ├── queue.ts # Execution queue +│ │ │ ├── snapshot.ts # Multi-agent snapshots +│ │ │ ├── plugins.ts # Multi-agent plugins +│ │ │ └── index.ts +│ │ │ +│ │ ├── interventions/ # Intervention system for authorization, guardrails, steering +│ │ │ ├── __tests__/ +│ │ │ ├── actions.ts +│ │ │ ├── handler.ts +│ │ │ ├── registry.ts +│ │ │ └── index.ts +│ │ │ +│ │ ├── plugins/ # Plugin system +│ │ │ ├── __tests__/ +│ │ │ ├── plugin.ts +│ │ │ ├── registry.ts +│ │ │ ├── model-plugin.ts # Clears agent messages after invocation when model is stateful +│ │ │ └── index.ts +│ │ │ +│ │ ├── registry/ # Tool registry +│ │ │ ├── __tests__/ +│ │ │ └── tool-registry.ts +│ │ │ +│ │ ├── retry/ # Retry strategies for model calls +│ │ │ ├── __tests__/ +│ │ │ ├── backoff-strategy.ts +│ │ │ ├── model-retry-strategy.ts # Abstract ModelRetryStrategy base class +│ │ │ ├── default-model-retry-strategy.ts +│ │ │ ├── retry-strategy.ts # RetryStrategy union type + dedup helper +│ │ │ └── index.ts +│ │ │ +│ │ ├── sandbox/ # Sandbox abstraction for agent code execution +│ │ │ ├── __tests__/ +│ │ │ ├── base.ts # Abstract Sandbox class +│ │ │ ├── posix-shell.ts # PosixShellSandbox with shell-based defaults +│ │ │ ├── stream-process.ts # ChildProcess-to-AsyncGenerator bridge +│ │ │ ├── constants.ts # Language validation pattern +│ │ │ └── types.ts # ExecutionResult, StreamChunk, FileInfo, OutputFile +│ │ │ +│ │ ├── session/ # Session management +│ │ │ ├── __tests__/ +│ │ │ ├── session-manager.ts +│ │ │ ├── storage.ts # Storage interface +│ │ │ ├── file-storage.ts # File-based storage +│ │ │ ├── s3-storage.ts # S3 storage +│ │ │ ├── types.ts +│ │ │ ├── validation.ts +│ │ │ └── index.ts +│ │ │ +│ │ ├── telemetry/ # OpenTelemetry tracing and metrics +│ │ │ ├── __tests__/ +│ │ │ ├── tracer.ts +│ │ │ ├── meter.ts +│ │ │ ├── config.ts +│ │ │ ├── json.ts +│ │ │ ├── types.ts +│ │ │ ├── utils.ts +│ │ │ └── index.ts +│ │ │ +│ │ ├── tools/ # Tool definitions and types +│ │ │ ├── __tests__/ +│ │ │ ├── function-tool.ts +│ │ │ ├── mcp-tool.ts +│ │ │ ├── noop-tool.ts +│ │ │ ├── structured-output-tool.ts +│ │ │ ├── tool-factory.ts +│ │ │ ├── tool.ts +│ │ │ ├── zod-tool.ts +│ │ │ ├── zod-utils.ts +│ │ │ └── types.ts +│ │ │ +│ │ ├── types/ # Core type definitions +│ │ │ ├── __tests__/ +│ │ │ ├── agent.ts +│ │ │ ├── citations.ts +│ │ │ ├── elicitation.ts +│ │ │ ├── interrupt.ts +│ │ │ ├── json.ts +│ │ │ ├── lifecycle-observer.ts +│ │ │ ├── media.ts +│ │ │ ├── messages.ts +│ │ │ ├── serializable.ts +│ │ │ ├── snapshot.ts +│ │ │ └── validation.ts +│ │ │ +│ │ ├── utils/ # Shared utility functions +│ │ │ └── shell-quote.ts # Shell-safe string escaping +│ │ │ +│ │ ├── vended-interventions/ # Optional vended intervention handlers +│ │ │ ├── hitl/ # Human-in-the-loop approval handler +│ │ │ │ ├── __tests__/ +│ │ │ │ ├── hitl.ts +│ │ │ │ └── index.ts +│ │ │ └── steering/ # Steering handler base + LLM-driven steering +│ │ │ ├── __tests__/ +│ │ │ ├── handlers/ +│ │ │ │ ├── handler.ts +│ │ │ │ └── llm.ts +│ │ │ ├── providers/ +│ │ │ │ ├── context-provider.ts +│ │ │ │ └── tool-ledger.ts +│ │ │ └── index.ts +│ │ │ +│ │ ├── vended-plugins/ # Optional vended plugins +│ │ │ ├── index.ts # Barrel export for all plugins +│ │ │ ├── context-offloader/ # Context offloading plugin +│ │ │ │ ├── __tests__/ +│ │ │ │ ├── plugin.ts +│ │ │ │ ├── storage.ts +│ │ │ │ └── index.ts +│ │ │ └── skills/ # AgentSkills plugin +│ │ │ ├── __tests__/ +│ │ │ ├── agent-skills.ts +│ │ │ ├── skill.ts +│ │ │ └── index.ts +│ │ │ +│ │ ├── vended-tools/ # Optional vended tools +│ │ │ ├── index.ts # Barrel export for all tools +│ │ │ ├── bash/ +│ │ │ ├── file-editor/ +│ │ │ ├── http-request/ +│ │ │ └── notebook/ +│ │ │ +│ │ ├── errors.ts # Custom error classes +│ │ ├── index.ts # Main SDK entry point +│ │ ├── interrupt.ts # Interrupt handling +│ │ ├── mcp.ts # MCP client implementation +│ │ ├── mcp-config.ts # MCP config file parsing +│ │ ├── mime.ts # MIME type utilities +│ │ └── state-store.ts # State store implementation +│ │ +│ ├── generated/ # Auto-generated WIT type declarations +│ │ ├── interfaces/ # Per-interface type definitions +│ │ └── strands:agent.d.ts # Top-level WIT agent declaration +│ │ +│ ├── test/ # Tests outside of source +│ │ ├── integ/ # Integration tests +│ │ │ ├── __fixtures__/ # Integration test fixtures +│ │ │ ├── __resources__/ # Static resources for integration tests +│ │ │ ├── a2a/ +│ │ │ ├── conversation-manager/ +│ │ │ ├── mcp/ +│ │ │ ├── models/ +│ │ │ │ └── openai/ +│ │ │ ├── multiagent/ +│ │ │ ├── skills/ +│ │ │ ├── tools/ +│ │ │ └── agent.test.ts +│ │ └── packages/ # Package compatibility tests (CJS/ESM) +│ │ +│ ├── examples/ # Example applications +│ │ ├── agents-as-tools/ +│ │ ├── browser-agent/ +│ │ ├── first-agent/ +│ │ ├── graph/ +│ │ ├── mcp/ +│ │ ├── swarm/ +│ │ └── telemetry/ +│ │ +│ ├── package.json # SDK package config and dependencies +│ ├── tsconfig.base.json # TypeScript configuration +│ ├── vitest.config.ts # Testing configuration +│ └── eslint.config.js # Linting configuration +│ +├── strands-wasm/ # WASM build tooling +│ ├── __fixtures__/ # Vitest module mocks for WIT imports +│ ├── __tests__/ # Unit tests for entry.ts internals +│ ├── generated/ # Auto-generated WIT type declarations +│ │ └── interfaces/ # Per-interface type definitions +│ ├── test/ # Tests outside of source +│ │ └── guest/ # Tests that load the compiled WASM component +│ ├── docs/ # WASM-specific documentation +│ ├── patches/ # Runtime patches for WASM compatibility +│ │ └── getChunkedStream.js +│ ├── entry.ts # WASM entry point (TS SDK surface for WASM compilation) +│ ├── build.js # Build script for WASM compilation +│ ├── package.json # WASM package configuration +│ ├── vitest.config.ts # Test configuration (unit + guest projects) +│ └── tsconfig.json # TypeScript type-check configuration +│ +├── strands-py-wasm/ # Python SDK bindings (WASM-based) +│ ├── strands/ # Python package source +│ │ ├── _generated/ # Auto-generated type bindings +│ │ ├── agent/ # Agent implementation +│ │ │ └── conversation_manager/ +│ │ ├── event_loop/ # Event loop and retry logic +│ │ ├── models/ # Model providers (Bedrock, Anthropic, OpenAI, Gemini) +│ │ ├── multiagent/ # Multi-agent orchestration (Graph, Swarm) +│ │ ├── session/ # Session management (file, S3) +│ │ ├── tools/ # Tool definitions and MCP client +│ │ │ └── mcp/ +│ │ ├── types/ # Type definitions +│ │ ├── _conversions.py # Type conversions between TS and Python +│ │ ├── _wasm_host.py # WASM host runtime bridge +│ │ ├── hooks.py # Hooks system +│ │ └── interrupt.py # Interrupt handling +│ ├── scripts/ # Build/codegen scripts +│ │ └── generate_types.py # Type generation from WIT definitions +│ ├── examples/ # Example applications +│ ├── tests_integ/ # Integration tests +│ ├── pyproject.toml # Python package configuration +│ └── pyrightconfig.json # Python type checking configuration +│ +├── strandly/ # Developer CLI tooling +│ ├── scripts/ +│ │ └── generate_types.py # Type generation script +│ ├── src/ +│ │ └── cli.ts # CLI entry point +│ ├── package.json # Dev CLI package configuration +│ └── tsconfig.json # TypeScript configuration +│ +├── wit/ # WebAssembly Interface Type definitions +│ ├── deps/ # WIT dependency interfaces +│ │ ├── clocks/clocks.wit +│ │ └── io/io.wit +│ ├── agent.wit # Top-level WIT world definition +│ ├── conversation.wit # Conversation management interfaces +│ ├── logging.wit # Logging interfaces +│ ├── mcp.wit # MCP protocol interfaces +│ ├── messages.wit # Message type definitions +│ ├── models.wit # Model provider interfaces +│ ├── multiagent.wit # Multi-agent interfaces +│ ├── retry.wit # Retry strategy interfaces +│ ├── sessions.wit # Session management interfaces +│ ├── streaming.wit # Streaming event interfaces +│ ├── tools.wit # Tool interfaces +│ └── vended.wit # Vended plugin/tool interfaces +│ +├── dev-docs/ # Project documentation +│ ├── TESTING.md # Comprehensive testing guidelines +│ ├── DEPENDENCIES.md # Dependency management guidelines +│ ├── DIVERGENCES.md # Divergences from Python SDK +│ └── PR.md # Pull request guidelines and template +│ +├── .github/ # GitHub configuration +│ ├── ISSUE_TEMPLATE/ # Issue templates (bug report, feature request) +│ ├── PULL_REQUEST_TEMPLATE.md # PR template +│ └── workflows/ # CI/CD workflows +│ +├── .husky/ # Git hooks (pre-commit checks) +│ +├── package.json # Root workspace config (delegates to strands-ts) +├── .prettierrc # Code formatting configuration +├── .gitignore # Git ignore rules +│ +├── AGENTS.md # This file (agent guidance) +├── COMPATIBILITY.MD # Compatibility documentation +├── CONTRIBUTING.md # Human contributor guidelines +└── README.md # Project overview and usage +``` + +### Directory Purposes + +- **`strands-ts/`**: The SDK workspace package containing all source, tests, and examples +- **`strands-ts/src/`**: All production code with co-located unit tests +- **`strands-ts/src/__fixtures__/`**: Shared test fixtures (mock models, helpers) +- **`strands-ts/src/a2a/`**: Agent-to-agent protocol (A2A client, server, adapters, logging) +- **`strands-ts/src/agent/`**: Agent loop coordination, output printing, snapshots +- **`strands-ts/src/conversation-manager/`**: Conversation history management strategies +- **`strands-ts/src/hooks/`**: Hooks system for event-driven extensibility +- **`strands-ts/src/logging/`**: Structured logging utilities +- **`strands-ts/src/models/`**: Model provider implementations (Bedrock, Anthropic, OpenAI, Google, Vercel) +- **`strands-ts/src/multiagent/`**: Multi-agent orchestration patterns (Graph for DAG execution, Swarm for handoff-based routing) +- **`strands-ts/src/plugins/`**: Plugin system for extending agent functionality +- **`strands-ts/src/registry/`**: Tool registry implementation +- **`strands-ts/src/retry/`**: Retry strategies for model calls (backoff strategies, abstract `ModelRetryStrategy` plugin base class, concrete `DefaultModelRetryStrategy`) +- **`strands-ts/src/sandbox/`**: Sandbox abstraction for agent code execution (abstract `Sandbox` base class, `PosixShellSandbox` base for shell-based implementations) +- **`strands-ts/src/session/`**: Session management (file, S3, custom storage) +- **`strands-ts/src/telemetry/`**: OpenTelemetry tracing and metrics +- **`strands-ts/src/tools/`**: Tool definitions, types, and structured output validation with Zod schemas +- **`strands-ts/src/types/`**: Core type definitions used across the SDK +- **`strands-ts/src/utils/`**: Shared utility functions +- **`strands-ts/src/vended-interventions/`**: Optional vended intervention handlers (hitl, steering — not part of core SDK, independently importable) +- **`strands-ts/src/vended-plugins/`**: Optional vended plugins (context-offloader, skills — not part of core SDK, independently importable) +- **`strands-ts/src/vended-tools/`**: Optional vended tools (bash, file-editor, http-request, notebook) +- **`strands-ts/generated/`**: Auto-generated WIT interface type declarations +- **`strands-ts/test/integ/`**: Integration tests (tests public API and external integrations) +- **`strands-ts/examples/`**: Example applications +- **`strands-wasm/`**: WASM build tooling for compiling the TS SDK to WebAssembly +- **`strands-wasm/generated/`**: Auto-generated WIT interface type declarations for WASM +- **`strands-wasm/test/guest/`**: Tests that load the compiled WASM component +- **`strands-wasm/docs/`**: WASM-specific development documentation +- **`strands-py-wasm/`**: Python SDK bindings powered by the TS SDK compiled to WASM +- **`strands-py-wasm/strands/`**: Python package source with agent, models, multiagent, session, tools, and type modules +- **`strands-py-wasm/scripts/`**: Build and codegen scripts (type generation from WIT definitions) +- **`strands-py-wasm/tests_integ/`**: Python integration tests +- **`strandly/`**: Developer CLI tooling for local development workflows (install on PATH via `npm install && npm link -w strandly`, then call `strandly …`) +- **`wit/`**: WebAssembly Interface Type (WIT) definitions defining the contract between the TS SDK and WASM hosts +- **`wit/deps/`**: External WIT dependency interfaces (clocks, io) +- **`dev-docs/`**: Project documentation (testing guidelines, dependency management, divergences, PR guidelines) +- **`.github/workflows/`**: CI/CD automation and quality gates + +**IMPORTANT**: After making changes that affect the directory structure (adding new directories, moving files, or adding significant new files), you MUST update this directory structure section to reflect the current state of the repository. + +## Development Workflow for Agents + +### 1. Environment Setup + +See [CONTRIBUTING.md - Development Environment](CONTRIBUTING.md#development-environment) for: + +- Prerequisites (Node.js 20+, npm) +- Installation steps +- Verification commands + +### 2. Making Changes + +1. **Create feature branch**: `git checkout -b agent-tasks/{ISSUE_NUMBER}` +2. **Implement changes** following the patterns below +3. **Run quality checks** before committing (pre-commit hooks will run automatically) +4. **Commit with conventional commits**: `feat:`, `fix:`, `refactor:`, `docs:`, etc. +5. **Push to remote**: `git push origin agent-tasks/{ISSUE_NUMBER}` +6. **Create pull request** following [PR.md](dev-docs/PR.md) guidelines + +### 3. Pull Request Guidelines + +When creating pull requests, you **MUST** follow the guidelines in [PR.md](dev-docs/PR.md). Key principles: + +- **Focus on WHY**: Explain motivation and user impact, not implementation details +- **Document public API changes**: Show before/after code examples +- **Be concise**: Use prose over bullet lists; avoid exhaustive checklists +- **Target senior engineers**: Assume familiarity with the SDK +- **Exclude implementation details**: Leave these to code comments and diffs + +See [PR.md](dev-docs/PR.md) for the complete guidance and template. + +### 4. Quality Gates + +Pre-commit hooks automatically run: + +- Build (via npm run build, required for workspace type resolution) +- Unit tests with coverage (via npm run test:coverage) +- WASM unit tests (via npm run test -w strands-wasm) +- Linting (via npm run lint) +- Format checking (via npm run format:check) +- Type checking (via npm run type-check) + +All checks must pass before commit is allowed. + +### 5. Testing Guidelines + +When writing tests, you **MUST** follow the guidelines in [dev-docs/TESTING.md](dev-docs/TESTING.md). Key topics covered: + +- Test organization and file location +- Test batching strategy +- Object assertion best practices +- Test coverage requirements +- Multi-environment testing (Node.js and browser) + +See [TESTING.md](dev-docs/TESTING.md) for the complete testing reference. + +## Coding Patterns and Best Practices + +### Logging Style Guide + +The SDK uses a structured logging format consistent with the Python SDK for better log parsing and searchability. + +**Format**: + +```typescript +// With context fields +logger.warn(`field1=<${value1}>, field2=<${value2}> | human readable message`) + +// Without context fields +logger.warn('human readable message') + +// Multiple statements in message (use pipe to separate) +logger.warn(`field=<${value}> | statement one | statement two`) +``` + +**Guidelines**: + +1. **Context Fields** (when relevant): + - Add context as `field=` pairs at the beginning + - Use commas to separate pairs + - Enclose values in `<>` for readability (especially helpful for empty values: `field=<>`) + - Use template literals for string interpolation + +2. **Messages**: + - Add human-readable messages after context fields + - Use lowercase for consistency + - Avoid punctuation (periods, exclamation points) to reduce clutter + - Keep messages concise and focused on a single statement + - If multiple statements are needed, separate them with pipe character (`|`) + +**Examples**: + +```typescript +// Good: Context fields with message +logger.warn(`stop_reason=<${stopReason}>, fallback=<${fallback}> | unknown stop reason, converting to camelCase`) +logger.warn(`event_type=<${eventType}> | unsupported bedrock event type`) + +// Good: Simple message without context fields +logger.warn('cache points are not supported in openai system prompts, ignoring cache points') + +// Good: Multiple statements separated by pipes +logger.warn(`request_id=<${id}> | processing request | starting validation`) + +// Bad: Not using angle brackets for values +logger.warn(`stop_reason=${stopReason} | unknown stop reason`) + +// Bad: Using punctuation +logger.warn(`event_type=<${eventType}> | Unsupported event type.`) +``` + +### Import Organization + +Use relative imports for internal modules: + +```typescript +// Good: Relative imports for internal modules +import { hello } from './hello' +import { Agent } from '../agent' + +// Good: External dependencies +import { something } from 'external-package' +``` + +### File Organization Pattern + +**For source files**: + +``` +strands-ts/src/ +├── module.ts # Source file +└── __tests__/ + └── module.test.ts # Unit tests co-located +``` + +**Function ordering within files**: + +- Functions MUST be ordered from most general to most specific (top-down reading) +- Public/exported functions MUST appear before private helper functions +- Main entry point functions MUST be at the top of the file +- Helper functions SHOULD follow in order of their usage + +**Example**: + +```typescript +// Good: Main function first, helpers follow +export async function* mainFunction() { + const result = await helperFunction1() + return helperFunction2(result) +} + +async function helperFunction1() { + // Implementation +} + +function helperFunction2(input: string) { + // Implementation +} + +// Bad: Helpers before main function +async function helperFunction1() { + // Implementation +} + +export async function* mainFunction() { + const result = await helperFunction1() + return helperFunction2(result) +} +``` + +**For integration tests**: + +``` +strands-ts/test/integ/ +└── feature.test.ts # Tests public API +``` + +### TypeScript Type Safety + +**Optional chaining for null safety**: Prefer optional chaining over verbose `typeof` checks when accessing potentially undefined properties: + +```typescript +// Good: Optional chaining +return globalThis?.process?.env?.API_KEY + +// Bad: Verbose typeof checks +if (typeof process !== 'undefined' && typeof process.env !== 'undefined') { + return process.env.API_KEY +} +return undefined +``` + +**Strict requirements**: + +```typescript +// Good: Explicit return types +export function process(input: string): string { + return input.toUpperCase() +} + +// Bad: No return type +export function process(input: string) { + return input.toUpperCase() +} + +// Good: Proper typing +export function getData(): { id: number; name: string } { + return { id: 1, name: 'test' } +} + +// Bad: Using any +export function getData(): any { + return { id: 1, name: 'test' } +} +``` + +**Rules**: + +- Always provide explicit return types +- Never use `any` type (enforced by ESLint) +- Use TypeScript strict mode features +- Leverage type inference where appropriate + +### Class Field Naming Conventions + +**Private fields**: Use underscore prefix for private class fields to improve readability and distinguish them from public members. + +```typescript +// Good: Private fields with underscore prefix +export class Example { + private readonly _config: Config + private _state: State + + constructor(config: Config) { + this._config = config + this._state = { initialized: false } + } + + public getConfig(): Config { + return this._config + } +} + +// Bad: No underscore for private fields +export class Example { + private readonly config: Config // Missing underscore + + constructor(config: Config) { + this.config = config + } +} +``` + +**Rules**: + +- Private fields MUST use underscore prefix (e.g., `_field`) +- Public fields MUST NOT use underscore prefix +- This convention improves code readability and makes the distinction between public and private members immediately visible + +#### Naming Conventions for New Features + +When choosing names and constants that match an existing implementation in the Python SDK, use exactly the same literal used +in the Python SDK. Wherever we can achieve compatibility, keep the previous convention. + +#### Plugin Naming + +Name plugins for what they do, not for the `Plugin` interface they implement. + +```typescript +// Good +export class AgentSkills implements Plugin { ... } +export class DefaultModelRetryStrategy implements Plugin { ... } + +// Bad +export class AgentSkillsPlugin implements Plugin { ... } +export class DefaultModelRetryStrategyPlugin implements Plugin { ... } +``` + +Same rule for the associated config (`AgentSkillsConfig`, not `AgentSkillsPluginConfig`). + +### Documentation Requirements + +**TSDoc format** (required for all exported functions): + +````typescript +/** + * Brief description of what the function does. + * + * @param paramName - Description of the parameter + * @param optionalParam - Description of optional parameter + * @returns Description of what is returned + * + * @example + * ```typescript + * const result = functionName('input') + * console.log(result) // "output" + * ``` + */ +export function functionName(paramName: string, optionalParam?: number): string { + // Implementation +} +```` + +**Interface property documentation**: + +```typescript +/** + * Interface description. + */ +export interface MyConfig { + /** + * Single-line description of the property. + */ + propertyName: string + + /** + * Single-line description with optional reference link. + * @see https://docs.example.com/property-details + */ + anotherProperty?: number +} +``` + +**Requirements**: + +- All exported functions, classes, and interfaces must have TSDoc +- Include `@param` for all parameters +- Include `@returns` for return values +- Include `@example` only for exported classes (main SDK entry points like BedrockModel, Agent) +- Do NOT include `@example` for type definitions, interfaces, or internal types +- Interface properties MUST have single-line descriptions +- Interface properties MAY include an optional `@see` link for additional details +- TSDoc validation enforced by ESLint + +### Code Style Guidelines + +**Formatting** (enforced by Prettier): + +- No semicolons +- Single quotes +- Line length: 120 characters +- Tab width: 2 spaces +- Trailing commas in ES5 style + +**Example**: + +```typescript +export function example(name: string, options?: Options): Result { + const config = { + name, + enabled: true, + settings: { + timeout: 5000, + retries: 3, + }, + } + + return processConfig(config) +} +``` + +### Import Organization + +Organize imports in this order: + +```typescript +// 1. External dependencies +import { something } from 'external-package' + +// 2. Internal modules (using relative paths) +import { Agent } from '../agent' +import { Tool } from '../tools' + +// 3. Types (if separate) +import type { Options, Config } from '../types' +``` + +### Interface and Type Organization + +**When defining interfaces or types, organize them so the top-level interface comes first, followed by its dependencies, and then all nested dependencies.** + +```typescript +// Correct - Top-level first, then dependencies +export interface Message { + role: Role + content: ContentBlock[] +} + +export type Role = 'user' | 'assistant' + +export type ContentBlock = TextBlock | ToolUseBlock | ToolResultBlock + +export class TextBlock { + readonly type = 'textBlock' as const + readonly text: string + constructor(data: { text: string }) { + this.text = data.text + } +} + +export class ToolUseBlock { + readonly type = 'toolUseBlock' as const + readonly name: string + readonly toolUseId: string + readonly input: JSONValue + constructor(data: { name: string; toolUseId: string; input: JSONValue }) { + this.name = data.name + this.toolUseId = data.toolUseId + this.input = data.input + } +} + +export class ToolResultBlock { + readonly type = 'toolResultBlock' as const + readonly toolUseId: string + readonly status: 'success' | 'error' + readonly content: ToolResultContent[] + constructor(data: { toolUseId: string; status: 'success' | 'error'; content: ToolResultContent[] }) { + this.toolUseId = data.toolUseId + this.status = data.status + this.content = data.content + } +} + +// Wrong - Dependencies before top-level +export type Role = 'user' | 'assistant' + +export interface TextBlockData { + text: string +} + +export interface Message { + // Top-level should come first + role: Role + content: ContentBlock[] +} +``` + +**Rationale**: This ordering makes files more readable by providing an overview first, then details. + +### Discriminated Union Naming Convention + +**When creating discriminated unions with a `type` field, the type value MUST match the interface name with the first letter lowercase.** + +```typescript +// Correct - type matches class name (first letter lowercase) +export class TextBlock { + readonly type = 'textBlock' as const // Matches 'TextBlock' class name + readonly text: string + constructor(data: { text: string }) { + this.text = data.text + } +} + +export class CachePointBlock { + readonly type = 'cachePointBlock' as const // Matches 'CachePointBlock' class name + readonly cacheType: 'default' + constructor(data: { cacheType: 'default' }) { + this.cacheType = data.cacheType + } +} + +export type ContentBlock = TextBlock | ToolUseBlock | CachePointBlock + +// Wrong - type doesn't match class name +export class CachePointBlock { + readonly type = 'cachePoint' as const // Should be 'cachePointBlock' + readonly cacheType: 'default' +} +``` + +**Rationale**: This consistent naming makes discriminated unions predictable and improves code readability. Developers can easily understand the relationship between the type value and the class. + +### API Union Types (Bedrock Pattern) + +When the upstream API (e.g., Bedrock) defines a type as a **UNION** ("only one member can be specified"), model it as a TypeScript `type` union with each variant's field **required** — not an `interface` with optional fields. This allows non-breaking expansion when new variants are added. + +The Bedrock API marks all fields in union types as "Not Required" as a mechanism for future extensibility. In TypeScript, encode the mutual exclusivity using `|` with each variant having its field required. The "not required" from the API docs means "this field won't be present if a different variant is active." + +```typescript +// Correct: type union — each variant has its field required +// Adding a new variant later (e.g., | { image: ImageData }) is non-breaking +export type CitationSourceContent = { text: string } + +// Correct: multi-variant union with object-key discrimination +export type DocumentSourceData = + | { bytes: Uint8Array } + | { text: string } + | { content: DocumentContentBlockData[] } + | { location: S3LocationData } + +// Correct: multi-variant union for citation locations +export type CitationLocation = + | { documentChar: DocumentCharLocation } + | { documentPage: DocumentPageLocation } + | { web: WebLocation } + +// Wrong: interface with optional fields — cannot expand without breaking +export interface CitationSourceContent { + text?: string +} + +// Wrong: interface with required field — changing to union later is breaking +export interface CitationSourceContent { + text: string +} +``` + +**Key points**: + +- Use `type` alias (not `interface`) so it can be expanded to a union later +- Each variant's field is **required** within that variant +- Use object-key discrimination (`'text' in source`) to narrow variants at runtime +- See `DocumentSourceData` in `strands-ts/src/types/media.ts` and `CitationLocation` in `strands-ts/src/types/citations.ts` for reference implementations + +### Error Handling + +```typescript +// Good: Explicit error handling +export function process(input: string): string { + if (!input) { + throw new Error('Input cannot be empty') + } + return input.trim() +} + +// Good: Custom error types +export class ValidationError extends Error { + constructor(message: string) { + super(message) + this.name = 'ValidationError' + } +} +``` + +**Key Features:** + +- Automatic tool discovery and registration +- Lazy connection (connects on first use) +- Supports stdio and HTTP transports +- Resource cleanup with `Symbol.dispose` + +**See [`examples/mcp/`](strands-ts/examples/mcp/) for complete working examples.** + +### Test Assertions + +When asserting on objects, prefer `toStrictEqual` for full object comparison rather than checking individual fields: + +```typescript +// Good: Full object assertion with toStrictEqual +expect(provider.getConfig()).toStrictEqual({ + modelId: 'gemini-2.5-flash', + params: { temperature: 0.5 }, +}) + +// Bad: Checking individual fields +expect(provider.getConfig().modelId).toBe('gemini-2.5-flash') +expect(provider.getConfig().params.temperature).toBe(0.5) +``` + +**Rationale**: Full object assertions catch unexpected properties and ensure the complete shape is correct. + +### Dependency Management + +When adding or modifying dependencies, you **MUST** follow the guidelines in [dev-docs/DEPENDENCIES.md](dev-docs/DEPENDENCIES.md). Key points: + +- **`dependencies`**: Core SDK functionality that users don't interact with directly +- **`peerDependencies`**: Dependencies that cross API boundaries (users construct/pass instances) +- **`devDependencies`**: Build tools, testing frameworks, linters - not shipped to users + +**Rule**: If a dependency crosses an API boundary, it **MUST** be a peer dependency. + +## Things to Do + +**Do**: + +- Use relative imports for internal modules +- Co-locate unit tests with source under `__tests__` directories +- Follow nested describe pattern for test organization +- Write explicit return types for all functions +- Document all exported functions with TSDoc +- Use meaningful variable and function names +- Keep functions small and focused (single responsibility) +- Use async/await for asynchronous operations +- Handle errors explicitly + +## Things NOT to Do + +**Don't**: + +- Use `any` type (enforced by ESLint) +- Put unit tests in separate `tests/` directory (use `strands-ts/src/**/__tests__/**`) +- Skip documentation for exported functions +- Use semicolons (Prettier will remove them) +- Commit without running pre-commit hooks +- Ignore linting errors +- Skip type checking +- Use implicit return types + +## Development Commands + +For detailed command usage, see [CONTRIBUTING.md - Testing Instructions](CONTRIBUTING.md#testing-instructions-and-best-practices). + +Quick reference: + +```bash +npm test # Run unit tests in Node.js +npm run test:browser # Run unit tests in browser (Chromium via Playwright) +npm run test:all # Run all tests in all environments +npm run test:integ # Run integration tests +npm run test:coverage # Run tests with coverage report +npm run lint # Check code quality +npm run format # Auto-fix formatting +npm run type-check # Verify TypeScript types +npm run build # Compile TypeScript +``` + +## Troubleshooting Common Issues + +If TypeScript compilation fails: + +1. Run `npm run type-check` to see all type errors +2. Ensure all functions have explicit return types +3. Verify no `any` types are used +4. Check that all imports are correctly typed + +## Agent-Specific Notes + +### When Implementing Features + +1. **Read task requirements** carefully from the GitHub issue +2. **Follow TDD approach** if appropriate: + - Write failing tests first + - Implement minimal code to pass tests + - Refactor while keeping tests green +3. **Use existing patterns** as reference +4. **Document as you go** with TSDoc comments +5. **Run all checks** before committing (pre-commit hooks will enforce this) + +### Writing code + +- YOU MUST make the SMALLEST reasonable changes to achieve the desired outcome. +- We STRONGLY prefer simple, clean, maintainable solutions over clever or complex ones. Readability and maintainability are PRIMARY CONCERNS, even at the cost of conciseness or performance. +- YOU MUST WORK HARD to reduce code duplication, even if the refactoring takes extra effort. +- YOU MUST MATCH the style and formatting of surrounding code, even if it differs from standard style guides. Consistency within a file trumps external standards. +- YOU MUST NOT manually change whitespace that does not affect execution or output. Otherwise, use a formatting tool. +- Fix broken things immediately when you find them. Don't ask permission to fix bugs. + +#### Code Comments + +- NEVER add comments explaining that something is "improved", "better", "new", "enhanced", or referencing what it used to be +- Comments should explain WHAT the code does or WHY it exists, not how it's better than something else +- YOU MUST NEVER add comments about what used to be there or how something has changed. +- YOU MUST NEVER refer to temporal context in comments (like "recently refactored" "moved") or code. Comments should be evergreen and describe the code as it is. +- YOU MUST NEVER write overly verbose comments. Use concise language. + +### Code Review Considerations + +When responding to PR feedback: + +- Address all review comments +- Test changes thoroughly +- Update documentation if behavior changes +- Maintain test coverage +- Follow conventional commit format for fix commits + +### Integration with Other Files + +- **CONTRIBUTING.md**: Contains testing/setup commands and human contribution guidelines +- **dev-docs/TESTING.md**: Comprehensive testing guidelines (MUST follow when writing tests) +- **dev-docs/PR.md**: Pull request guidelines and template +- **README.md**: Public-facing documentation, links to strandsagents.com +- **package.json**: Root workspace config that delegates to strands-ts +- **strands-ts/package.json**: SDK package config, dependencies, and npm scripts + +## Additional Resources + +- [TypeScript Handbook](https://www.typescriptlang.org/docs/handbook/intro.html) +- [Vitest Documentation](https://vitest.dev/) +- [TSDoc Reference](https://tsdoc.org/) +- [Conventional Commits](https://www.conventionalcommits.org/) +- [Strands Agents Documentation](https://strandsagents.com/) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..5b627cfa60 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,4 @@ +## Code of Conduct +This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). +For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact +opensource-codeofconduct@amazon.com with any additional questions or comments. diff --git a/COMPATIBILITY.MD b/COMPATIBILITY.MD new file mode 100644 index 0000000000..a38a77a3f2 --- /dev/null +++ b/COMPATIBILITY.MD @@ -0,0 +1,93 @@ +# Compatibility Policy + +This document outlines the Strands TypeScript SDK's policy on changes that are **not considered breaking changes** under semantic versioning. Understanding these policies helps you anticipate how the SDK may evolve without requiring major version bumps. + +## Field to Getter/Setter Conversion + +Converting a public mutable field to a property with getter/setter methods **is not considered a breaking change**, even when adding validation or side effects. + +### Policy + +The SDK may convert public mutable fields to getter/setter properties in minor or patch releases. This includes adding: +- Validation logic that throws errors for invalid values +- Side effects during assignment (logging, notifications, state updates) +- Computed or transformed values in getters + +### Rationale + +In TypeScript and JavaScript, getter/setter properties are syntactically and behaviorally identical to direct field access from the consumer's perspective: + +```typescript +// Before: Direct field access +agent.model = newModel +const currentModel = agent.model + +// After: Getter/setter (identical usage) +agent.model = newModel // Calls setter +const currentModel = agent.model // Calls getter +``` + +Consumers cannot distinguish between direct field access and property access at the call site. The implementation change is transparent to user code. + +### Example + +The `Agent.model` property is currently a public mutable field. In a future release, it may be converted to a getter/setter to add validation: + +```typescript +// Current implementation (field) +public model: Model + +// Possible future implementation (getter/setter with validation) +private _model: Model +public get model(): Model { + return this._model +} +public set model(value: Model) { + if (!value) { + throw new Error('Model cannot be null or undefined') + } + this._model = value +} +``` + +User code remains unchanged and continues to work as before. + +## Union Type Extensions + +Adding new types or classes to union types **is not considered a breaking change**, unless the union explicitly declares that it will no longer change. + +### Policy + +The SDK may add new event types, result variants, or other union members in minor or patch releases. This includes: +- New event types in streaming results +- Additional error types in result unions +- New configuration options in config unions +- Extended enum-like union types + +### Rationale + +Union type extensions are additive changes that don't break existing code. + +Consumers handle union types through type guards, switch statements, or pattern matching that focus on known variants. + +New union members are simply ignored by existing logic. + +### Example + +The `AgentStreamEvent` type returned by `Agent.stream()` may receive new event types: + +```typescript +// Current usage (continues to work) +for await (const event of agent.stream('Hello')) { + if (event.type === 'textDelta') { + console.log(event.text) + } + // New event types are ignored by existing code +} +``` + +New event types added to the union don't affect existing event handling logic. + +## Feedback + +If you have questions or concerns about this compatibility policy, please [open an issue](https://github.com/strands-agents/sdk-typescript/issues) on GitHub. diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 19dc35b243..0000000000 --- a/LICENSE +++ /dev/null @@ -1,175 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. \ No newline at end of file diff --git a/site/LICENSE b/LICENSE.APACHE similarity index 100% rename from site/LICENSE rename to LICENSE.APACHE diff --git a/LICENSE.MIT b/LICENSE.MIT new file mode 100644 index 0000000000..5d2ae23ae8 --- /dev/null +++ b/LICENSE.MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Strands Agents Contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/dev-docs/DEPENDENCIES.md b/dev-docs/DEPENDENCIES.md new file mode 100644 index 0000000000..2dfe94015d --- /dev/null +++ b/dev-docs/DEPENDENCIES.md @@ -0,0 +1,64 @@ +# Dependency Guidelines - Strands TypeScript SDK + +> **IMPORTANT**: When adding or modifying dependencies, you **MUST** follow the guidelines in this document. These patterns ensure proper dependency resolution for SDK consumers and avoid version conflicts. + +| Category | When to Use | +| ---------------------- | ----------------------------------------------------------------------- | +| `dependencies` | Core SDK functionality that users don't interact with directly | +| `peerDependencies` | Dependencies that cross API boundaries (users construct/pass instances) | +| `devDependencies` | Build tools, testing frameworks, linters - not shipped to users | +| `peerDependenciesMeta` | Mark peer dependencies as optional when not all users need them | + +## Peer Dependencies + +Peer dependencies are packages the consuming application provides. The SDK relies on the user's installed version, ensuring both operate on the same instance and avoiding version conflicts. + +**Rule**: If a dependency crosses an API boundary, it **MUST** be a peer dependency. + +**Example**: `zod` is a peer dependency because users construct Zod schemas and pass them to the SDK: + +```typescript +import { z } from 'zod' +import { Agent, tool } from '@strands-agents/sdk' + +const calculator = tool({ + name: 'calculator', + inputSchema: z.object({ value: z.number() }), + callback: (input) => input.value * 2, +}) + +const agent = new Agent({ model, tools: [calculator] }) +``` + +Mark peer dependencies as **optional** when not all users need them (e.g., model provider SDKs). Optional peer dependencies must also be added to `devDependencies` for SDK development and testing. + +## Package Lock File + +The `package-lock.json` file ensures reproducible builds by locking exact dependency versions. + +| Command | When to Use | +|---------|-------------| +| `npm ci` | Installing dependencies without changes (fresh clone, after pulling, CI pipelines) | +| `npm install` | Adding, removing, or updating dependencies | + +`npm ci` installs exactly what's in the lock file without modifying it, failing if there's a mismatch. This prevents accidental lock file changes. + +**When to modify:** + +- Adding, removing, or updating dependencies in `package.json` +- Running `npm audit fix` to patch security vulnerabilities + +After modifying dependencies, regenerate the lock file for all platforms: + +```bash +npm run lock:refresh +``` + +This generates a lock file that includes platform-specific optional dependencies for Linux, macOS, and Windows (both x64 and arm64), ensuring `npm ci` works in CI regardless of where the lock file was generated. + +**Rules:** + +1. Never manually edit `package-lock.json` - always use `npm install` or `npm update` +2. Always run `npm run lock:refresh` after modifying dependencies to ensure cross-platform compatibility +3. Commit `package-lock.json` changes in the same commit as the corresponding `package.json` changes +4. If `package-lock.json` has merge conflicts, delete it and run `npm run lock:refresh` to regenerate diff --git a/dev-docs/DIVERGENCES.md b/dev-docs/DIVERGENCES.md new file mode 100644 index 0000000000..4ee9d00841 --- /dev/null +++ b/dev-docs/DIVERGENCES.md @@ -0,0 +1,59 @@ +# Divergences + +Two lists. + +- **Proposed TS SDK changes** — places where the TS surface should move toward + what the WIT contract already says. Open proposals, not landed changes. +- **`strands-py-wasm` vs upstream `sdk-python`** — deliberate differences in our + Python package, given the agent loop runs in WASM here. + +--- + +## Proposed TS SDK changes + +- `StopData` → `StopEvent`. Every other terminal arm in the stream is `*Event` (`MetadataEvent`, `LifecycleEvent`). WIT is already `stop-event`. +- Merge `StreamEvent` + `AgentStreamEvent`. The abstract `StreamEvent` base class has no runtime behavior worth keeping; make `StreamEvent` the union directly. +- `InterruptResponseContent` → `InterruptResponseBlock`. Every other content block class ends in `Block` (`TextBlock`, `ImageBlock`, `ReasoningBlock`). WIT is `interrupt-response-block`. +- Split multi-agent `Status` into `OrchestrationStatus` (in-flight arms) and `TerminalStatus` (final arms). WIT has the split; TS conflated them. +- Consider dropping `*Event` suffix on hook event classes where it reads as "event-event" (`ModelStreamUpdateEvent`, `ContentBlockEvent`). Low priority. +- `SaveLatestStrategy` should become a variant, not a string union. Currently `'message' | 'invocation' | 'trigger'` with the trigger callback registered separately; WIT carries the handler id inline on `trigger(string)`. TS should follow: `{ tag: 'trigger'; handlerId: string } | 'message' | 'invocation'`. +- `Usage.totalTokens` is redundant with `inputTokens + outputTokens`. Either drop, or doc it as the canonical provider-reported value that may diverge. Same issue exists in WIT; decide once and apply both places. +- `Message.metadata.custom` is `Record` in TS, opaque JSON string in WIT. Type WIT (tracks with the typed-snapshot work) or drop the TS structure. Don't leave them mismatched. +- Group `name`, `id`, `description` into an `agentIdentity` sub-record on `AgentConfig`. WIT already nests them; TS has three top-level fields. +- `ToolResultBlock.status` should be an enum, not `'success' | 'error'` string union. Matches WIT `tool-result-status`. +- `ToolResultBlock.content` should be a typed discriminated union (matching WIT `tool-result-content`), not a raw union reduced at runtime. +- `CachePointBlock.type: 'default'` should be a `CacheKind` enum. Matches WIT `cache-kind`. +- `ConversationManagerConfig` should be a typed variant (`none` | `slidingWindow` | `summarizing`), not a flat record with a string `strategy` discriminator. The wasmtime-py limitation that motivated the flat shape has been lifted. +- `TraceContext` should be a typed record (`traceparent`, `tracestate?`), not JSON W3C headers. +- `CustomModelConfig.stateful` should be a static field on the config, not a per-invocation `Model.stateful` getter. Stateful providers are identified once at registration. +- MCP transport names: if TS still uses a generic `http` transport, rename to `streamableHttp` per the current MCP spec. WIT is `streamable-http`. +- Consider replacing millisecond `number` fields (retries, graph/swarm timeouts, MCP task polling) with a `Duration` type. WIT uses `wasi:clocks/monotonic-clock.duration`. +- Add a `StrandsError extends Error` base class and reparent every SDK-thrown error to it. Today `ModelError` extends `Error`; `JsonValidationError`, `ToolValidationError`, `StructuredOutputError`, `ConcurrentInvocationError`, `SessionError` extend `Error` directly with no shared root. Users can't `catch (e instanceof StrandsError)`. +- Export `SessionError` and `ProviderTokenCountError` from `src/index.ts`. Both exist in `src/errors.ts` but neither is in the package index or any subpath. +- Surface hook event errors as typed objects rather than strings. `AfterModelCallEvent.error` should carry a `ModelError` (or an `unknown` widened to the `ModelError` union); same for `AfterToolCallEvent.error` → `ToolError`. WIT already carries typed errors into the hook payloads. +- `Plugin` should carry a doc banner that it's TS-only. Host-side SDKs can't implement it without a `custom-plugin` WIT interface that doesn't exist yet; users reading the type today have no way to know. +- Adopt the jco-generated types verbatim for `Usage`, `Metrics`, and `StopReason`. Today the WASM bridge maps camelCase TS types to kebab-case WIT variants on every event; having the SDK's runtime types *be* the generated types deletes those translators. +- `Agent.stream()` should yield the WIT `stream-event` variant directly, not the current class hierarchy of `ModelStreamUpdateEvent` / `ContentBlockEvent` / `ToolResultEvent` / etc. The hook-registration API (`agent.addHook(BeforeModelCallEvent, cb)`) stays; only the stream payload changes shape. Collapses ~300 lines of guest-to-host translation. +- SDK constructors for `Agent`, `SessionManager`, `ConversationManager`, tool registry, and built-in model providers (Bedrock/Anthropic/OpenAI/Google) should accept the WIT-shaped config records (`agent-config`, `session-config`, `conversation-manager-config`, `tool-spec`, `model-config` arms) directly. Eliminates the `buildSystemPrompt` / `createSessionManager` / `createConversationManager` / `createToolChoiceProxy` / `createTools` / `createModel` translators in the WASM bridge. +- `Agent` should implement the WIT `api` interface directly (or componentize-js should generate the resource shims). Removes `AgentImpl` / `ResponseStreamImpl` from the bridge. Enables the final test: running componentize-js on the SDK with no `entry.ts` produces an equivalent `.wasm`. + +--- + +## `strands-py-wasm` vs upstream `sdk-python` + +- `types/content.py` + `types/tools.py` → all generated types live in `_generated.py` (one file, auto-written from the WIT bindings). +- `agent/conversation_manager/*.py` → `SlidingWindowConversationManager` / `SummarizingConversationManager` are config-only dataclasses; execution is in the WASM guest. +- `session/file_session_manager.py` + `s3_session_manager.py` → `FileStorage` / `S3Storage` are config-only passthroughs to the WIT `storage_config`. +- `telemetry/`, `plugins/`, `hooks/`, `handlers/`, `event_loop/` — agent loop runs in the guest; these modules are either removed or collapsed into the thin SDK surface. +- Users never see base64 — binary content arrives as `bytes`. +- Users never format or parse ISO-8601 — snapshot timestamps come through as `wasi:clocks/wall-clock.datetime` records. +- Two error surfaces: the WIT tagged-variant records (`StorageError`, `ModelError`, `ToolError`, etc.) for pattern matching, plus `Exception` classes (`StrandsError`, `ContextWindowOverflowError`, `ToolValidationError`, …) for raise/catch. No stringly-typed error payloads. +- No custom `FileSessionManager` / `S3SessionManager` classes. Users pass `FileStorage(base_dir=...)` or `S3Storage(bucket=...)`; the guest instantiates the backend. +- Custom storage: set `StorageConfig_Custom(backend_id=...)` and implement the `snapshot-storage` host interface. No extra config record needed. +- `save_latest_policy.trigger` holds the handler id inline. Upstream's optional trigger-callback-on-config field is gone. +- `Graph`, `Swarm`, and `McpClient` are config-builder subclasses of the generated WIT records, not host-side orchestration runtimes. Orchestration and MCP transport management run in the guest. +- Interrupts are stream events, not exceptions. Upstream raises `InterruptException` from hooks and aggregates them in the registry; strands-py-wasm emits `StreamEventInterrupt(value=Interrupt)` on the event stream and resumes via `agent.respond(interrupt_id, payload)`. The `HookRegistry` does not interpret or aggregate interrupts. +- `HookRegistry` has no `order=` / `HookOrder` knobs; LIFO dispatch for `After*` arms is inferred from the class name. Upstream has a `should_reverse_callbacks` property on each event; our inference replaces the hand-set property. +- No type-hint inference on `add_callback`. Users pass `event_type` explicitly. Upstream's `add_callback(None, fn)` auto-inference added a `_type_inference.py` module we consider more trouble than it's worth. +- No `BaseHookEvent.__setattr__` immutability gate. Our hook events come from the WIT generator as `@dataclass` records; if immutability becomes required we'll add `frozen=True` at the generator level for both wire and hook consumers. +- New Python-layer types that aren't in upstream: `PydanticTool` (analog to TS's `ZodTool`), `McpClient` + `StdioMcpTransport` / `StreamableHttpMcpTransport` / `SseMcpTransport` (subclass config builders over the generated WIT records), `AgentResult` (matches the TS SDK class for the terminal invocation value). diff --git a/dev-docs/PR.md b/dev-docs/PR.md new file mode 100644 index 0000000000..d92a2a1172 --- /dev/null +++ b/dev-docs/PR.md @@ -0,0 +1,200 @@ +# Pull Request Description Guidelines + +Good PR descriptions help reviewers understand the context and impact of your changes. They enable faster reviews, better decision-making, and serve as valuable historical documentation. + +When creating a PR, follow the [GitHub PR template](../.github/PULL_REQUEST_TEMPLATE.md) and use these guidelines to fill it out effectively. + +## Who's Reading Your PR? + +Write for senior engineers familiar with the SDK. Assume your reader: + +- Understands the SDK's architecture and patterns +- Has context about the broader system +- Can read code diffs to understand implementation details +- Values concise, focused communication + +## What to Include + +Every PR description should have: + +1. **Motivation** — Why is this change needed? +2. **Public API Changes** — What changes to the public API (with code snippets)? +3. **Use Cases** (optional) — When would developers use this feature? Only include for non-obvious functionality; skip for trivial changes or obvious fixes. +4. **Breaking Changes** (if applicable) — What breaks and how to migrate? + +## Writing Principles + +**Focus on WHY, not HOW:** + +- ✅ "The OpenAI SDK supports dynamic API keys, but we don't expose this capability" +- ❌ "Added ApiKeySetter type import from openai/client" + +**Document public API changes with example code snippets:** + +- ✅ Show before/after code snippets for API changes +- ❌ List every file or line changed + +**Be concise:** + +- ✅ Use prose over bullet lists when possible +- ❌ Create exhaustive implementation checklists + +**Emphasize user impact:** + +- ✅ "Enables secret manager integration for credential rotation" +- ❌ "Updated error message to mention 'string or function'" + +## What to Skip + +Leave these out of your PR description: + +- **Implementation details** — Code comments and commit messages cover this +- **Test coverage notes** — CI will catch issues; assume tests are comprehensive +- **Line-by-line change lists** — The diff provides this +- **Build/lint/coverage status** — CI handles verification +- **Commit hashes** — GitHub links commits automatically + +## Anti-patterns + +❌ **Over-detailed checklists:** + +```markdown +### Type Definition Updates + +- Added ApiKeySetter type import from 'openai/client' +- Updated OpenAIModelOptions interface apiKey type +``` + +❌ **Implementation notes reviewers don't need:** + +```markdown +## Implementation Notes + +- No breaking changes - all existing string-based usage continues to work +- OpenAI SDK handles validation of function return values +``` + +❌ **Test coverage bullets:** + +```markdown +### Test Coverage + +- Added test: accepts function-based API key +- Added test: accepts async function-based API key +``` + +## Good Examples + +✅ **Motivation section:** + +```markdown +## Motivation + +The OpenAI SDK supports dynamic API key resolution through async functions, +enabling use cases like credential rotation and secret manager integration. +However, our SDK currently only accepts static strings for the apiKey parameter, +preventing users from leveraging these capabilities. +``` + +✅ **Public API Changes section:** + +````markdown +## Public API Changes + +The `OpenAIModelOptions.apiKey` parameter now accepts either a string or an +async function: + +```typescript +// Before: only string supported +const model = new OpenAIModel({ + modelId: 'gpt-4o', + apiKey: 'sk-...', +}) + +// After: function also supported +const model = new OpenAIModel({ + modelId: 'gpt-4o', + apiKey: async () => await secretManager.getApiKey(), +}) +``` + +The change is backward compatible—all existing string-based usage continues +to work without modification. + +```` + +✅ **Use Cases section:** +```markdown +## Use Cases + +- **API key rotation**: Rotate keys without application restart +- **Secret manager integration**: Fetch credentials from AWS Secrets Manager, Vault, etc. +- **Multi-tenant systems**: Dynamically select API keys based on context +```` + +## Template + +````markdown +## Motivation + +[Explain WHY this change is needed. What problem does it solve? What limitation +does it address? What user need does it fulfill?] + +Resolves: #[issue-number] + +## Public API Changes + +[Document changes to public APIs with before/after code snippets. If no public +API changes, state "No public API changes."] + +```typescript +// Before +[existing API usage] + +// After +[new API usage] +``` + +[Explain behavior, parameters, return values, and backward compatibility.] + +## Use Cases (optional) + +[Only include for non-obvious functionality. Provide 1-3 concrete use cases +showing when developers would use this feature. Skip for trivial changes obvious fixes..] + +## Breaking Changes (if applicable) + +[If this is a breaking change, explain what breaks and provide migration guidance.] + +### Migration + +```typescript +// Before +[old code] + +// After +[new code] +``` + +```` + +## Why These Guidelines? + +**Focus on WHY over HOW** because code diffs show implementation details, commit messages document granular changes, and PR descriptions provide the broader context reviewers need. + +**Skip test/lint/coverage details** because CI pipelines verify these automatically. Including them adds noise without value. + +**Write for senior engineers** to enable concise, technical communication without redundant explanations. + +## References + +- [Conventional Commits](https://www.conventionalcommits.org/) +- [Google's Code Review Guidelines](https://google.github.io/eng-practices/review/) + +## Checklist Items + + - [ ] Does the PR description target a Senior Engineer familiar with the project? + - [ ] Does the PR description give an overview of the feature being implemented, including any notes on key implemention decisions + - [ ] Does the PR include a "Resolves #" in the body and is not bolded? + - [ ] Does the PR contain the motivation or use-cases behind the change? + - [ ] Does the PR omit irrelevent details not needed for historical reference? diff --git a/dev-docs/TESTING.md b/dev-docs/TESTING.md new file mode 100644 index 0000000000..40251b948a --- /dev/null +++ b/dev-docs/TESTING.md @@ -0,0 +1,738 @@ +# Testing Guidelines - Strands TypeScript SDK + +> **IMPORTANT**: When writing tests, you **MUST** follow the guidelines in this document. These patterns ensure consistency, maintainability, and proper test coverage across the SDK. + +This document contains comprehensive testing guidelines for the Strands TypeScript SDK. For general development guidance, see [AGENTS.md](../AGENTS.md). + +## Test Fixtures Quick Reference + +All test fixtures are located in `src/__fixtures__/`. Use these helpers to reduce boilerplate and ensure consistency. + +| Fixture | File | When to Use | Details | +| ---------------------- | ----------------------- | ------------------------------------------------------------------------------------ | --------------------------------------------------------------------------- | +| `MockMessageModel` | `mock-message-model.ts` | Agent loop tests - specify content blocks, auto-generates stream events | [Model Fixtures](#model-fixtures-mock-message-modelts-model-test-helpersts) | +| `TestModelProvider` | `model-test-helpers.ts` | Low-level model tests - precise control over individual `ModelStreamEvent` sequences | [Model Fixtures](#model-fixtures-mock-message-modelts-model-test-helpersts) | +| `collectIterator()` | `model-test-helpers.ts` | Collect all items from any async iterable into an array | [Model Fixtures](#model-fixtures-mock-message-modelts-model-test-helpersts) | +| `collectGenerator()` | `model-test-helpers.ts` | Collect yielded items AND final return value from async generators | [Model Fixtures](#model-fixtures-mock-message-modelts-model-test-helpersts) | +| `MockHookProvider` | `mock-hook-provider.ts` | Record and verify hook invocations during agent execution | [Hook Fixtures](#hook-fixtures-mock-hook-providerts) | +| `createMockTool()` | `tool-helpers.ts` | Create mock tools with custom result behavior | [Tool Fixtures](#tool-fixtures-tool-helpersts) | +| `createRandomTool()` | `tool-helpers.ts` | Create minimal mock tools when execution doesn't matter | [Tool Fixtures](#tool-fixtures-tool-helpersts) | +| `createMockContext()` | `tool-helpers.ts` | Create mock `ToolContext` for testing tool implementations directly | [Tool Fixtures](#tool-fixtures-tool-helpersts) | +| `createMockAgent()` | `agent-helpers.ts` | Create minimal mock Agent with messages and state | [Agent Fixtures](#agent-fixtures-agent-helpersts) | +| `expectAgentResult()` | `agent-helpers.ts` | Assert on `AgentResult` with expected stop reason, message text, cycle count, and traces | [Agent Fixtures](#agent-fixtures-agent-helpersts) | +| `createCancellableAgent()` | `agent-helpers.ts` | Create a minimal `InvokableAgent` that sleeps for a configurable delay and aborts early when its `cancelSignal` fires — used for timeout/cancellation tests | [Agent Fixtures](#agent-fixtures-agent-helpersts) | +| `isNode` / `isBrowser` | `environment.ts` | Environment detection for conditional test execution | [Environment Fixtures](#environment-fixtures-environmentts) | +| `MockSpan` | `mock-span.ts` | Mock OTEL Span that records all setAttribute/addEvent/end calls for assertion | [Telemetry Fixtures](#telemetry-fixtures-mock-spants-mock-meterts) | +| `eventAttr()` | `mock-span.ts` | Extract a string attribute from a mock span event | [Telemetry Fixtures](#telemetry-fixtures-mock-spants-mock-meterts) | +| `MockMeter` | `mock-meter.ts` | Mock OTEL Meter that records all counter/histogram instrument calls for assertion | [Telemetry Fixtures](#telemetry-fixtures-mock-spants-mock-meterts) | +| `expectLoopMetrics()` | `metrics-helpers.ts` | Assert on `AgentMetrics` with expected cycle count, tool names, and optional token usage | [Metrics Fixtures](#metrics-fixtures-metrics-helpersts) | +| `findMetricValue()` | `metrics-helpers.ts` | Find the latest data point value for a named OTEL metric from ResourceMetrics | [Metrics Fixtures](#metrics-fixtures-metrics-helpersts) | + +## Test Organization + +### Unit Test Location + +**Rule**: Unit test files are co-located with source files, grouped in a directory named `__tests__` + +``` +src/subdir/ +├── agent.ts # Source file +├── model.ts # Source file +└── __tests__/ + ├── agent.test.ts # Tests for agent.ts + └── model.test.ts # Tests for model.ts +``` + +### Integration Test Location + +**Rule**: Integration tests are separate in `tests_integ/` + +``` +tests_integ/ +├── api.test.ts # Tests public API +└── environment.test.ts # Tests environment compatibility +``` + +### Test File Naming + +**File naming determines which environment(s) tests run in:** + +- `*.test.ts` — runs in **both** Node.js and browser environments +- `*.test.node.ts` — runs **only** in Node.js environment +- `*.test.browser.ts` — runs **only** in browser environment + +This naming convention applies to both unit tests (`src/**/__tests__/`) and integration tests (`test/integ/`). + +**Examples:** + +``` +src/module/__tests__/ +├── module.test.ts # Runs in Node.js AND browser +├── module.test.node.ts # Runs in Node.js only +└── module.test.browser.ts # Runs in browser only +``` + +Use environment-specific test files when tests depend on platform-specific features like filesystem access, environment variables, or browser APIs. + +## Test Structure Pattern + +Follow this nested describe pattern for consistency: + +### For Functions + +```typescript +import { describe, it, expect } from 'vitest' +import { functionName } from '../module' + +describe('functionName', () => { + describe('when called with valid input', () => { + it('returns expected result', () => { + const result = functionName('input') + expect(result).toBe('expected') + }) + }) + + describe('when called with edge case', () => { + it('handles gracefully', () => { + const result = functionName('') + expect(result).toBeDefined() + }) + }) +}) +``` + +### For Classes + +```typescript +import { describe, it, expect } from 'vitest' +import { ClassName } from '../module' + +describe('ClassName', () => { + describe('methodName', () => { + it('returns expected result', () => { + const instance = new ClassName() + const result = instance.methodName() + expect(result).toBe('expected') + }) + + it('handles error case', () => { + const instance = new ClassName() + expect(() => instance.methodName()).toThrow() + }) + }) + + describe('anotherMethod', () => { + it('performs expected action', () => { + // Test implementation + }) + }) +}) +``` + +### Key Principles + +- Top-level `describe` uses the function/class name +- Nested `describe` blocks group related test scenarios +- Use descriptive test names without "should" prefix +- Group tests by functionality or scenario + +## Writing Effective Tests + +```typescript +// Good: Clear, specific test +describe('calculateTotal', () => { + describe('when given valid numbers', () => { + it('returns the sum', () => { + expect(calculateTotal([1, 2, 3])).toBe(6) + }) + }) + + describe('when given empty array', () => { + it('returns zero', () => { + expect(calculateTotal([])).toBe(0) + }) + }) +}) + +// Bad: Vague, unclear test +describe('calculateTotal', () => { + it('works', () => { + expect(calculateTotal([1, 2, 3])).toBeTruthy() + }) +}) +``` + +## Test Batching Strategy + +**Rule**: When test setup cost exceeds test logic cost, you MUST batch related assertions into a single test. + +**You MUST batch when**: + +- Setup complexity > test logic complexity +- Multiple assertions verify the same object state +- Related behaviors share expensive context + +**You SHOULD keep separate tests for**: + +- Distinct behaviors or execution paths +- Error conditions +- Different input scenarios + +**Bad - Redundant setup**: + +```typescript +it('has correct tool name', () => { + const tool = createComplexTool({ + /* expensive setup */ + }) + expect(tool.toolName).toBe('testTool') +}) + +it('has correct description', () => { + const tool = createComplexTool({ + /* same expensive setup */ + }) + expect(tool.description).toBe('Test description') +}) +``` + +**Good - Batched properties**: + +```typescript +it('creates tool with correct properties', () => { + const tool = createComplexTool({ + /* setup once */ + }) + expect(tool.toolName).toBe('testTool') + expect(tool.description).toBe('Test description') + expect(tool.toolSpec.name).toBe('testTool') +}) +``` + +## Object Assertion Best Practices + +**Prefer testing entire objects at once** instead of individual properties for better readability and test coverage. + +```typescript +// ✅ Good: Verify entire object at once +it('returns expected user object', () => { + const user = getUser('123') + expect(user).toEqual({ + id: '123', + name: 'John Doe', + email: 'john@example.com', + isActive: true, + }) +}) + +// ✅ Good: Verify entire array of objects +it('yields expected stream events', async () => { + const events = await collectIterator(stream) + expect(events).toEqual([ + { type: 'streamEvent', data: 'Starting...' }, + { type: 'streamEvent', data: 'Processing...' }, + { type: 'streamEvent', data: 'Complete!' }, + ]) +}) + +// ❌ Bad: Testing individual properties +it('returns expected user object', () => { + const user = getUser('123') + expect(user).toBeDefined() + expect(user.id).toBe('123') + expect(user.name).toBe('John Doe') + expect(user.email).toBe('john@example.com') + expect(user.isActive).toBe(true) +}) + +// ❌ Bad: Testing array elements individually in a loop +it('yields expected stream events', async () => { + const events = await collectIterator(stream) + for (const event of events) { + expect(event.type).toBe('streamEvent') + expect(event).toHaveProperty('data') + } +}) +``` + +**Benefits of testing entire objects**: + +- **More concise**: Single assertion instead of multiple +- **Better test coverage**: Catches unexpected additional or missing properties +- **More readable**: Clear expectation of the entire structure +- **Easier to maintain**: Changes to the object require updating one place + +**Use cases**: + +- Always use `toEqual()` for object and array comparisons +- Use `toBe()` only for primitive values and reference equality +- When testing error objects, verify the entire structure including message and type + +## What to Test + +**Testing Approach:** + +- You **MUST** write tests for implementations (functions, classes, methods) +- You **SHOULD NOT** write tests for interfaces since TypeScript compiler already enforces type correctness +- You **SHOULD** write Vitest type tests (`*.test-d.ts`) for complex types to ensure backwards compatibility + +**Example Implementation Test:** + +```typescript +describe('BedrockModel', () => { + it('streams messages correctly', async () => { + const provider = new BedrockModel(config) + const stream = provider.stream(messages) + + for await (const event of stream) { + if (event.type === 'modelMessageStartEvent') { + expect(event.role).toBe('assistant') + } + } + }) +}) +``` + +## Test Coverage + +- **Minimum**: 80% coverage required (enforced by Vitest) +- **Target**: Aim for high coverage on critical paths +- **Exclusions**: Test files, config files, generated code + +## Test Model Providers + +**When to use each test provider:** + +- **`MockMessageModel`**: For agent loop tests and high-level flows - focused on content blocks +- **`TestModelProvider`**: For low-level event streaming tests where you need precise control over individual events + +### MockMessageModel - Content-Focused Testing + +For tests focused on messages, you SHOULD use `MockMessageModel` with a content-focused API that eliminates boilerplate: + +```typescript +import { MockMessageModel } from '../__fixtures__/mock-message-model' + +// ✅ RECOMMENDED - Single content block (most common) +const provider = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + +// ✅ RECOMMENDED - Array of content blocks +const provider = new MockMessageModel().addTurn([ + { type: 'textBlock', text: 'Let me help' }, + { type: 'toolUseBlock', name: 'calc', toolUseId: 'id-1', input: {} }, +]) + +// ✅ RECOMMENDED - Multi-turn with builder pattern +const provider = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calc', toolUseId: 'id-1', input: {} }) // Auto-derives 'toolUse' + .addTurn({ type: 'textBlock', text: 'The answer is 42' }) // Auto-derives 'endTurn' + +// ✅ OPTIONAL - Explicit stopReason when needed +const provider = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Partial response' }, { stopReason: 'maxTokens' }) + +// ✅ OPTIONAL - Token usage metadata (emits modelMetadataEvent after message stop) +const provider = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calc', toolUseId: 'id-1', input: {} }, { + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }, { + usage: { inputTokens: 200, outputTokens: 30, totalTokens: 230 }, + }) + +// ✅ OPTIONAL - Error handling +const provider = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'Success' }) + .addTurn(new Error('Model failed')) +``` + +## Testing Hooks + +When testing hook behavior, you **MUST** use `agent.hooks.addCallback()` for registering single callbacks when `agent.hooks` is available. Do NOT create inline `HookProvider` objects — this is an anti-pattern for single callbacks. + +```typescript +// ✅ CORRECT - Use agent.hooks.addCallback() for single callbacks +const agent = new Agent({ model, tools: [tool] }) + +agent.hooks.addCallback(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.toolUse = { + ...event.toolUse, + input: { value: 42 }, + } +}) + +// ✅ CORRECT - Use MockHookProvider to record and verify hook invocations +const hookProvider = new MockHookProvider() +const agent = new Agent({ model, hooks: [hookProvider] }) +await agent.invoke('Hi') +expect(hookProvider.invocations).toContainEqual(new BeforeInvocationEvent({ agent })) + +// ❌ WRONG - Do NOT create inline HookProvider objects +const switchToolHook = { + registerCallbacks: (registry: HookRegistry) => { + registry.addCallback(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + if (event.toolUse.name === 'tool1') { + event.tool = tool2 + } + }) + }, +} +``` + +**When to use each approach:** + +- **`agent.hooks.addCallback()`** - For adding a single callback to verify hook behavior (e.g., modifying tool input, switching tools) +- **`MockHookProvider`** - For recording and verifying hook lifecycle behavior and that specific hook events fired during execution + +## Test Fixtures Reference + +All test fixtures are located in `src/__fixtures__/`. Use these helpers to reduce boilerplate and ensure consistency. + +### Model Fixtures (`mock-message-model.ts`, `model-test-helpers.ts`) + +- **`MockMessageModel`** - Content-focused model for agent loop tests. Use `addTurn()` with content blocks. +- **`TestModelProvider`** - Low-level model for precise control over `ModelStreamEvent` sequences. +- **`collectIterator(stream)`** - Collects all items from an async iterable into an array. +- **`collectGenerator(generator)`** - Collects yielded items and final return value from an async generator. + +```typescript +// MockMessageModel for agent tests +const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calc', toolUseId: 'id-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + +// collectIterator for stream results +const events = await collectIterator(agent.stream('Hi')) +``` + +### Hook Fixtures (`mock-hook-provider.ts`) + +- **`MockHookProvider`** - Records all hook invocations for verification. Pass to `Agent({ hooks: [provider] })`. + - Use `{ includeModelEvents: false }` to exclude model streaming and result events from recordings. + - Access `provider.invocations` to verify hook events fired. + +```typescript +// Record and verify hook invocations +const hookProvider = new MockHookProvider({ includeModelEvents: false }) +const agent = new Agent({ model, hooks: [hookProvider] }) + +await agent.invoke('Hi') + +expect(hookProvider.invocations[0]).toEqual(new BeforeInvocationEvent({ agent })) +``` + +### Tool Fixtures (`tool-helpers.ts`) + +- **`createMockTool(name, resultFn)`** - Creates a mock tool with custom result behavior. +- **`createRandomTool(name?)`** - Creates a minimal mock tool (use when tool execution doesn't matter). +- **`createMockContext(toolUse, agentState?)`** - Creates a mock `ToolContext` for testing tool implementations directly. + +```typescript +// Mock tool with custom result +const tool = createMockTool( + 'calculator', + () => new ToolResultBlock({ toolUseId: 'id', status: 'success', content: [new TextBlock('42')] }) +) + +// Minimal tool when execution doesn't matter +const tool = createRandomTool('myTool') +``` + +**When to use fixtures vs `FunctionTool` directly:** + +Use `createMockTool()` or `createRandomTool()` when tools are incidental to the test. Use `FunctionTool` or `tool()` directly only when testing tool-specific behavior. + +```typescript +// ✅ Use fixtures when testing agent/hook behavior +const tool = createMockTool('testTool', () => ({ + type: 'toolResultBlock', + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('Success')], +})) +const agent = new Agent({ model, tools: [tool] }) + +// ❌ Don't use FunctionTool when tool behavior is irrelevant to the test +const tool = new FunctionTool({ name: 'testTool', description: '...', inputSchema: {...}, callback: ... }) +``` + +### Agent Fixtures (`agent-helpers.ts`) + +- **`createMockAgent(data?)`** - Creates a minimal mock Agent with messages and state. Use for testing components that need an Agent reference without full agent behavior. + +```typescript +const agent = createMockAgent({ + messages: [new Message({ role: 'user', content: [new TextBlock('Hi')] })], + state: { key: 'value' }, +}) +``` + +- **`expectAgentResult(options)`** - Creates an asymmetric matcher that validates `AgentResult` structure and values. Reduces deeply nested assertions by providing a clean, readable matcher that combines stop reason, message text, metrics, and traces validation. + +```typescript +import { expectAgentResult } from '../__fixtures__/agent-helpers' + +// ✅ RECOMMENDED - Clean, readable assertion +expect(result).toEqual( + expectAgentResult({ + stopReason: 'endTurn', + messageText: 'Hello, world!', + cycleCount: 1, + traceCount: 1, + }) +) + +// ✅ With tools and detailed metrics +expect(result).toEqual( + expectAgentResult({ + stopReason: 'endTurn', + messageText: 'The answer is 42', + cycleCount: 2, + toolNames: ['calculator'], + traceCount: 2, + usage: { inputTokens: 300, outputTokens: 80, totalTokens: 380 }, + }) +) + +// ✅ For detailed trace structure validation, follow up with specific assertions +expect(result).toEqual( + expectAgentResult({ + stopReason: 'endTurn', + messageText: 'Done', + cycleCount: 2, + toolNames: ['calc'], + }) +) +// Verify detailed trace structure +expect(result.traces).toEqual([ + expect.objectContaining({ + name: 'Cycle 1', + children: expect.arrayContaining([ + expect.objectContaining({ name: 'stream_messages' }), + expect.objectContaining({ name: 'Tool: calc' }), + ]), + }), + expect.objectContaining({ + name: 'Cycle 2', + children: expect.arrayContaining([expect.objectContaining({ name: 'stream_messages' })]), + }), +]) + +// ❌ AVOID - Deeply nested, hard to read +expect(result).toEqual( + expect.objectContaining({ + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ + role: 'assistant', + content: expect.arrayContaining([expect.objectContaining({ type: 'textBlock', text: 'Hello' })]), + }), + metrics: expectLoopMetrics({ cycleCount: 1 }), + traces: expect.arrayContaining([expect.objectContaining({ name: 'Cycle 1' })]), + }) +) +``` + +**Options:** + +- `stopReason` (required) - Expected stop reason ('endTurn', 'toolUse', 'maxTokens') +- `messageText` (optional) - Expected text content in last assistant message's TextBlock. When omitted, only validates message exists with role 'assistant' +- `cycleCount` (required) - Expected number of agent loop cycles +- `traceCount` (optional) - Expected exact number of traces. When omitted, validates at least one trace exists +- `toolNames` (optional) - Expected tool names that were invoked +- `usage` (optional) - Expected token usage. When omitted, validates shape with `expect.any(Number)` + +- **`createCancellableAgent(id, delayMs, structuredOutput?)`** - Creates a minimal `InvokableAgent` that sleeps for `delayMs` before resolving, aborting the sleep early when the invocation's `cancelSignal` fires. Use for exercising timeout and cancellation behavior in multi-agent orchestrators (swarm, graph) without standing up a full Agent. + +```typescript +import { createCancellableAgent } from '../__fixtures__/agent-helpers' + +// Plain slow agent for a nodeTimeout test +const slow = createCancellableAgent('slow', 100) + +// With a swarm handoff as structured output +const handoffAgent = createCancellableAgent('a', 30, { agentId: 'b', message: 'to b' }) +``` + +### Environment Fixtures (`environment.ts`) + +- **`isNode`** - Boolean that detects if running in Node.js environment. +- **`isBrowser`** - Boolean that detects if running in a browser environment. + +Use these for conditional test execution when tests depend on environment-specific features. + +```typescript +import { isNode } from '../__fixtures__/environment' + +// Skip tests that require Node.js features in browser +describe.skipIf(!isNode)('Node.js specific features', () => { + it('uses environment variables', () => { + expect(process.env.NODE_ENV).toBeDefined() + }) +}) +``` + +### Telemetry Fixtures (`mock-span.ts`, `mock-meter.ts`) + +- **`MockSpan`** - Implements the OTEL `Span` interface and records all calls (`setAttribute`, `addEvent`, `setStatus`, `end`, `recordException`) for assertion. Use with `vi.mock('@opentelemetry/api')` to intercept tracer span creation. + - Access `mockSpan.calls.setAttribute` etc. to verify recorded calls. + - Use `mockSpan.getAttributeValue(key)` to look up a specific attribute. + - Use `mockSpan.getEvents(name)` to filter events by name. +- **`eventAttr(event, key)`** - Extracts a string attribute from a mock span event's attributes map. +- **`MockMeter`** - Implements the OTEL `Meter` interface and records all instrument data points. Use with `vi.spyOn(otelMetrics, 'getMeter').mockReturnValue(mockMeter)` to intercept meter creation. + - Use `mockMeter.getCounter(name)` to retrieve a counter by metric name. + - Use `mockMeter.getHistogram(name)` to retrieve a histogram by metric name. + - Counters and histograms expose `.dataPoints` (array of `{ value, attributes }`) and `.sum` (total of all values). + +```typescript +import { MockSpan, eventAttr } from '../__fixtures__/mock-span' + +// Mock the OTEL API and inject MockSpan +const mockSpan = new MockSpan() +const mockStartSpan = vi.fn().mockReturnValue(mockSpan) +vi.mocked(trace.getTracer).mockReturnValue({ startSpan: mockStartSpan, startActiveSpan: vi.fn() }) + +// Assert on span attributes and events +expect(mockSpan.getAttributeValue('gen_ai.agent.name')).toBe('test-agent') +expect(mockSpan.getEvents('gen_ai.user.message')).toHaveLength(1) +expect(eventAttr(mockSpan.getEvents('gen_ai.choice')[0]!, 'finish_reason')).toBe('end_turn') +``` + +```typescript +import { MockMeter } from '../__fixtures__/mock-meter' + +// Mock the OTEL API and inject MockMeter +const mockMeter = new MockMeter() +vi.spyOn(otelMetrics, 'getMeter').mockReturnValue(mockMeter) + +const m = new Meter() +m.startNewInvocation() + +// Assert on collected metric values +expect(mockMeter.getCounter('gen_ai.agent.invocation.count')?.sum).toBe(1) +expect(mockMeter.getHistogram('gen_ai.agent.cycle.duration')?.sum).toBe(2000) +expect(mockMeter.getCounter('gen_ai.agent.tool.call.count')?.dataPoints).toStrictEqual([ + { value: 1, attributes: { 'gen_ai.tool.name': 'search' } }, +]) +``` + +### Metrics Fixtures (`metrics-helpers.ts`) + +- **`expectLoopMetrics({ cycleCount, toolNames?, usage? })`** - Creates an asymmetric matcher that validates `AgentMetrics` structure and values. When `usage` is provided, asserts exact token counts. When omitted, falls back to shape-level assertions with `expect.any(Number)`. +- **`findMetricValue(resourceMetrics, metricName)`** - Flattens the OTEL ResourceMetrics → ScopeMetrics → MetricData hierarchy and returns the value of the last data point for the matching metric name. Returns `undefined` if not found. + +```typescript +import { expectLoopMetrics } from '../__fixtures__/metrics-helpers' + +// Shape-level assertion (no concrete token counts) +expect(result).toEqual( + new AgentResult({ + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ role: 'assistant' }), + metrics: expectLoopMetrics({ cycleCount: 1 }), + }) +) + +// With tool names +expect(result).toEqual( + new AgentResult({ + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ role: 'assistant' }), + metrics: expectLoopMetrics({ cycleCount: 2, toolNames: ['calc'] }), + }) +) + +// With concrete token usage (pair with MockMessageModel usage param) +expect(result).toEqual( + new AgentResult({ + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ role: 'assistant' }), + metrics: expectLoopMetrics({ + cycleCount: 2, + toolNames: ['calc'], + usage: { inputTokens: 300, outputTokens: 80, totalTokens: 380 }, + }), + }) +) +``` + +```typescript +import { findMetricValue } from '../__fixtures__/metrics-helpers' + +// Find a counter value from OTEL InMemoryMetricExporter output +const cycleCount = findMetricValue(metricExporter.getMetrics(), 'gen_ai.agent.cycle.count') +expect(cycleCount).toBeGreaterThanOrEqual(1) + +// Check a histogram was emitted +const duration = findMetricValue(metrics, 'gen_ai.agent.cycle.duration') +expect(duration).toBeDefined() +``` + +## Multi-Environment Testing + +The SDK is designed to work seamlessly in both Node.js and browser environments. Our test suite validates this by running tests in both environments using Vitest's browser mode with Playwright. + +### Test Projects + +The test suite is organized into three projects: + +1. **unit-node** (green): Unit tests running in Node.js environment +2. **unit-browser** (cyan): Same unit tests running in Chromium browser +3. **integ** (magenta): Integration tests running in Node.js + +### Environment-Specific Test Patterns + +- You MUST write tests that are environment-agnostic unless they depend on Node.js features like filesystem or env-vars + +Some tests require Node.js-specific features (like process.env, AWS SDK) and should be skipped in browser environments: + +```typescript +import { describe, it, expect } from 'vitest' +import { isNode } from '../__fixtures__/environment' + +// Tests will run in Node.js, skip in browser +describe.skipIf(!isNode)('Node.js specific features', () => { + it('uses environment variables', () => { + // This test accesses process.env + expect(process.env.NODE_ENV).toBeDefined() + }) +}) +``` + +### Environment Variable Stubbing + +When stubbing environment variables with `vi.stubEnv()`, you do **not** need to wrap calls in `if (isNode)` conditions. Vitest handles this automatically across environments, and the vitest config has `unstubEnvs: true` which restores env vars after each test. + +```typescript +// ✅ CORRECT - No condition needed +beforeEach(() => { + vi.stubEnv('API_KEY', 'test-key') +}) + +// ❌ WRONG - Unnecessary condition +beforeEach(() => { + if (isNode) { + vi.stubEnv('API_KEY', 'test-key') + } +}) +``` + +Similarly, you do **not** need to call `vi.unstubAllEnvs()` in `afterEach` since the vitest config handles this automatically. + +## Development Commands + +```bash +npm test # Run unit tests in Node.js +npm run test:browser # Run unit tests in browser (Chromium via Playwright) +npm run test:all # Run all tests in all environments +npm run test:integ # Run integration tests +npm run test:coverage # Run tests with coverage report +``` + +For detailed command usage, see [CONTRIBUTING.md - Testing Instructions](../CONTRIBUTING.md#testing-instructions-and-best-practices). + +## Checklist Items + +- [ ] Do the tests use relevant helpers from `src/__fixtures__/` as noted in the "Test Fixtures Quick Reference" table above? +- [ ] Are recurring code or patterns extracted to functions for better usability/readability? +- [ ] Are tests focused on verifying one or two things only? +- [ ] Are tests written concisely enough that the bulk of each test is important to the test instead of boilerplate code? +- [ ] Are tests asserting on the entire object instead of specific fields? diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000000..6ce38e70bf --- /dev/null +++ b/package-lock.json @@ -0,0 +1,9075 @@ +{ + "name": "strands", + "version": "0.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "strands", + "version": "0.0.0", + "workspaces": [ + "strandly", + "strands-ts", + "strands-wasm" + ], + "devDependencies": { + "husky": "^9.1.7", + "prettier": "^3.7.4" + } + }, + "node_modules/@a2a-js/sdk": { + "version": "0.3.13", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "uuid": "^11.1.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@bufbuild/protobuf": "^2.10.2", + "@grpc/grpc-js": "^1.11.0", + "express": "^4.21.2 || ^5.1.0" + }, + "peerDependenciesMeta": { + "@bufbuild/protobuf": { + "optional": true + }, + "@grpc/grpc-js": { + "optional": true + }, + "express": { + "optional": true + } + } + }, + "node_modules/@a2a-js/sdk/node_modules/uuid": { + "version": "11.1.1", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-11.1.1.tgz", + "integrity": "sha512-vIYxrBCC/N/K+Js3qSN88go7kIfNPssr/hHCesKCQNAjmgvYS2oqr69kIufEG+O4+PfezOH4EbIeHCfFov8ZgQ==", + "dev": true, + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist/esm/bin/uuid" + } + }, + "node_modules/@ai-sdk/amazon-bedrock": { + "version": "4.0.96", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/anthropic": "3.0.71", + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.23", + "@smithy/eventstream-codec": "^4.0.1", + "@smithy/util-utf8": "^4.0.0", + "aws4fetch": "^1.0.20" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/anthropic": { + "version": "3.0.71", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.23" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/openai": { + "version": "3.0.53", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.23" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/provider": { + "version": "3.0.8", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/provider-utils": { + "version": "4.0.23", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.6" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@anthropic-ai/sdk": { + "version": "0.92.0", + "dev": true, + "license": "MIT", + "dependencies": { + "json-schema-to-ts": "^3.1.1" + }, + "bin": { + "anthropic-ai-sdk": "bin/cli" + }, + "peerDependencies": { + "zod": "^3.25.0 || ^4.0.0" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, + "node_modules/@aws-crypto/crc32": { + "version": "5.2.0", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/util": "^5.2.0", + "@aws-sdk/types": "^3.222.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/@aws-crypto/crc32c": { + "version": "5.2.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/util": "^5.2.0", + "@aws-sdk/types": "^3.222.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-crypto/sha1-browser": { + "version": "5.2.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/supports-web-crypto": "^5.2.0", + "@aws-crypto/util": "^5.2.0", + "@aws-sdk/types": "^3.222.0", + "@aws-sdk/util-locate-window": "^3.0.0", + "@smithy/util-utf8": "^2.0.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-crypto/sha1-browser/node_modules/@smithy/is-array-buffer": { + "version": "2.2.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/sha1-browser/node_modules/@smithy/util-buffer-from": { + "version": "2.2.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@smithy/is-array-buffer": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/sha1-browser/node_modules/@smithy/util-utf8": { + "version": "2.3.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/sha256-browser": { + "version": "5.2.0", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-js": "^5.2.0", + "@aws-crypto/supports-web-crypto": "^5.2.0", + "@aws-crypto/util": "^5.2.0", + "@aws-sdk/types": "^3.222.0", + "@aws-sdk/util-locate-window": "^3.0.0", + "@smithy/util-utf8": "^2.0.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-crypto/sha256-browser/node_modules/@smithy/is-array-buffer": { + "version": "2.2.0", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/sha256-browser/node_modules/@smithy/util-buffer-from": { + "version": "2.2.0", + "license": "Apache-2.0", + "dependencies": { + "@smithy/is-array-buffer": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/sha256-browser/node_modules/@smithy/util-utf8": { + "version": "2.3.0", + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/sha256-js": { + "version": "5.2.0", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/util": "^5.2.0", + "@aws-sdk/types": "^3.222.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/@aws-crypto/supports-web-crypto": { + "version": "5.2.0", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-crypto/util": { + "version": "5.2.0", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.222.0", + "@smithy/util-utf8": "^2.0.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-crypto/util/node_modules/@smithy/is-array-buffer": { + "version": "2.2.0", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/util/node_modules/@smithy/util-buffer-from": { + "version": "2.2.0", + "license": "Apache-2.0", + "dependencies": { + "@smithy/is-array-buffer": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-crypto/util/node_modules/@smithy/util-utf8": { + "version": "2.3.0", + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^2.2.0", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock": { + "version": "3.1033.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "^3.974.2", + "@aws-sdk/credential-provider-node": "^3.972.33", + "@aws-sdk/middleware-host-header": "^3.972.10", + "@aws-sdk/middleware-logger": "^3.972.10", + "@aws-sdk/middleware-recursion-detection": "^3.972.11", + "@aws-sdk/middleware-user-agent": "^3.972.32", + "@aws-sdk/region-config-resolver": "^3.972.12", + "@aws-sdk/token-providers": "3.1033.0", + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-endpoints": "^3.996.7", + "@aws-sdk/util-user-agent-browser": "^3.972.10", + "@aws-sdk/util-user-agent-node": "^3.973.18", + "@smithy/config-resolver": "^4.4.16", + "@smithy/core": "^3.23.15", + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/hash-node": "^4.2.14", + "@smithy/invalid-dependency": "^4.2.14", + "@smithy/middleware-content-length": "^4.2.14", + "@smithy/middleware-endpoint": "^4.4.30", + "@smithy/middleware-retry": "^4.5.3", + "@smithy/middleware-serde": "^4.2.18", + "@smithy/middleware-stack": "^4.2.14", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/node-http-handler": "^4.5.3", + "@smithy/protocol-http": "^5.3.14", + "@smithy/smithy-client": "^4.12.11", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-body-length-browser": "^4.2.2", + "@smithy/util-body-length-node": "^4.2.3", + "@smithy/util-defaults-mode-browser": "^4.3.47", + "@smithy/util-defaults-mode-node": "^4.2.52", + "@smithy/util-endpoints": "^3.4.1", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-retry": "^4.3.2", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock-runtime": { + "version": "3.1037.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/client-bedrock-runtime/-/client-bedrock-runtime-3.1037.0.tgz", + "integrity": "sha512-Evla4DUdBf1pQpQa7pbfquj7jRaRktkI0qGoWBJBXWB9wQISzJ8OEI4sHugk/W6SF47C7hMP/o3Z/XBrfnejCw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "^3.974.5", + "@aws-sdk/credential-provider-node": "^3.972.36", + "@aws-sdk/eventstream-handler-node": "^3.972.14", + "@aws-sdk/middleware-eventstream": "^3.972.10", + "@aws-sdk/middleware-host-header": "^3.972.10", + "@aws-sdk/middleware-logger": "^3.972.10", + "@aws-sdk/middleware-recursion-detection": "^3.972.11", + "@aws-sdk/middleware-user-agent": "^3.972.35", + "@aws-sdk/middleware-websocket": "^3.972.16", + "@aws-sdk/region-config-resolver": "^3.972.13", + "@aws-sdk/token-providers": "3.1037.0", + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-endpoints": "^3.996.8", + "@aws-sdk/util-user-agent-browser": "^3.972.10", + "@aws-sdk/util-user-agent-node": "^3.973.21", + "@smithy/config-resolver": "^4.4.17", + "@smithy/core": "^3.23.17", + "@smithy/eventstream-serde-browser": "^4.2.14", + "@smithy/eventstream-serde-config-resolver": "^4.3.14", + "@smithy/eventstream-serde-node": "^4.2.14", + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/hash-node": "^4.2.14", + "@smithy/invalid-dependency": "^4.2.14", + "@smithy/middleware-content-length": "^4.2.14", + "@smithy/middleware-endpoint": "^4.4.32", + "@smithy/middleware-retry": "^4.5.5", + "@smithy/middleware-serde": "^4.2.20", + "@smithy/middleware-stack": "^4.2.14", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/node-http-handler": "^4.6.1", + "@smithy/protocol-http": "^5.3.14", + "@smithy/smithy-client": "^4.12.13", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-body-length-browser": "^4.2.2", + "@smithy/util-body-length-node": "^4.2.3", + "@smithy/util-defaults-mode-browser": "^4.3.49", + "@smithy/util-defaults-mode-node": "^4.2.54", + "@smithy/util-endpoints": "^3.4.2", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-retry": "^4.3.4", + "@smithy/util-stream": "^4.5.25", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/client-bedrock-runtime/node_modules/@aws-sdk/token-providers": { + "version": "3.1037.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/token-providers/-/token-providers-3.1037.0.tgz", + "integrity": "sha512-csxa484KboWLs3f8jFQ5v9RwH8FVf0fQ+SO3GSXyu4Jtinhh4qXmOWLSVX30RBpB933dZaKGHGEXzEEY88NqRw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.5", + "@aws-sdk/nested-clients": "^3.997.3", + "@aws-sdk/types": "^3.973.8", + "@smithy/property-provider": "^4.2.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/client-cognito-identity": { + "version": "3.1033.0", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "^3.974.2", + "@aws-sdk/credential-provider-node": "^3.972.33", + "@aws-sdk/middleware-host-header": "^3.972.10", + "@aws-sdk/middleware-logger": "^3.972.10", + "@aws-sdk/middleware-recursion-detection": "^3.972.11", + "@aws-sdk/middleware-user-agent": "^3.972.32", + "@aws-sdk/region-config-resolver": "^3.972.12", + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-endpoints": "^3.996.7", + "@aws-sdk/util-user-agent-browser": "^3.972.10", + "@aws-sdk/util-user-agent-node": "^3.973.18", + "@smithy/config-resolver": "^4.4.16", + "@smithy/core": "^3.23.15", + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/hash-node": "^4.2.14", + "@smithy/invalid-dependency": "^4.2.14", + "@smithy/middleware-content-length": "^4.2.14", + "@smithy/middleware-endpoint": "^4.4.30", + "@smithy/middleware-retry": "^4.5.3", + "@smithy/middleware-serde": "^4.2.18", + "@smithy/middleware-stack": "^4.2.14", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/node-http-handler": "^4.5.3", + "@smithy/protocol-http": "^5.3.14", + "@smithy/smithy-client": "^4.12.11", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-body-length-browser": "^4.2.2", + "@smithy/util-body-length-node": "^4.2.3", + "@smithy/util-defaults-mode-browser": "^4.3.47", + "@smithy/util-defaults-mode-node": "^4.2.52", + "@smithy/util-endpoints": "^3.4.1", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-retry": "^4.3.2", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/client-s3": { + "version": "3.1033.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha1-browser": "5.2.0", + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "^3.974.2", + "@aws-sdk/credential-provider-node": "^3.972.33", + "@aws-sdk/middleware-bucket-endpoint": "^3.972.10", + "@aws-sdk/middleware-expect-continue": "^3.972.10", + "@aws-sdk/middleware-flexible-checksums": "^3.974.10", + "@aws-sdk/middleware-host-header": "^3.972.10", + "@aws-sdk/middleware-location-constraint": "^3.972.10", + "@aws-sdk/middleware-logger": "^3.972.10", + "@aws-sdk/middleware-recursion-detection": "^3.972.11", + "@aws-sdk/middleware-sdk-s3": "^3.972.31", + "@aws-sdk/middleware-ssec": "^3.972.10", + "@aws-sdk/middleware-user-agent": "^3.972.32", + "@aws-sdk/region-config-resolver": "^3.972.12", + "@aws-sdk/signature-v4-multi-region": "^3.996.19", + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-endpoints": "^3.996.7", + "@aws-sdk/util-user-agent-browser": "^3.972.10", + "@aws-sdk/util-user-agent-node": "^3.973.18", + "@smithy/config-resolver": "^4.4.16", + "@smithy/core": "^3.23.15", + "@smithy/eventstream-serde-browser": "^4.2.14", + "@smithy/eventstream-serde-config-resolver": "^4.3.14", + "@smithy/eventstream-serde-node": "^4.2.14", + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/hash-blob-browser": "^4.2.15", + "@smithy/hash-node": "^4.2.14", + "@smithy/hash-stream-node": "^4.2.14", + "@smithy/invalid-dependency": "^4.2.14", + "@smithy/md5-js": "^4.2.14", + "@smithy/middleware-content-length": "^4.2.14", + "@smithy/middleware-endpoint": "^4.4.30", + "@smithy/middleware-retry": "^4.5.3", + "@smithy/middleware-serde": "^4.2.18", + "@smithy/middleware-stack": "^4.2.14", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/node-http-handler": "^4.5.3", + "@smithy/protocol-http": "^5.3.14", + "@smithy/smithy-client": "^4.12.11", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-body-length-browser": "^4.2.2", + "@smithy/util-body-length-node": "^4.2.3", + "@smithy/util-defaults-mode-browser": "^4.3.47", + "@smithy/util-defaults-mode-node": "^4.2.52", + "@smithy/util-endpoints": "^3.4.1", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-retry": "^4.3.2", + "@smithy/util-stream": "^4.5.23", + "@smithy/util-utf8": "^4.2.2", + "@smithy/util-waiter": "^4.2.16", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/client-secrets-manager": { + "version": "3.1033.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "^3.974.2", + "@aws-sdk/credential-provider-node": "^3.972.33", + "@aws-sdk/middleware-host-header": "^3.972.10", + "@aws-sdk/middleware-logger": "^3.972.10", + "@aws-sdk/middleware-recursion-detection": "^3.972.11", + "@aws-sdk/middleware-user-agent": "^3.972.32", + "@aws-sdk/region-config-resolver": "^3.972.12", + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-endpoints": "^3.996.7", + "@aws-sdk/util-user-agent-browser": "^3.972.10", + "@aws-sdk/util-user-agent-node": "^3.973.18", + "@smithy/config-resolver": "^4.4.16", + "@smithy/core": "^3.23.15", + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/hash-node": "^4.2.14", + "@smithy/invalid-dependency": "^4.2.14", + "@smithy/middleware-content-length": "^4.2.14", + "@smithy/middleware-endpoint": "^4.4.30", + "@smithy/middleware-retry": "^4.5.3", + "@smithy/middleware-serde": "^4.2.18", + "@smithy/middleware-stack": "^4.2.14", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/node-http-handler": "^4.5.3", + "@smithy/protocol-http": "^5.3.14", + "@smithy/smithy-client": "^4.12.11", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-body-length-browser": "^4.2.2", + "@smithy/util-body-length-node": "^4.2.3", + "@smithy/util-defaults-mode-browser": "^4.3.47", + "@smithy/util-defaults-mode-node": "^4.2.52", + "@smithy/util-endpoints": "^3.4.1", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-retry": "^4.3.2", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/client-sts": { + "version": "3.1033.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "^3.974.2", + "@aws-sdk/credential-provider-node": "^3.972.33", + "@aws-sdk/middleware-host-header": "^3.972.10", + "@aws-sdk/middleware-logger": "^3.972.10", + "@aws-sdk/middleware-recursion-detection": "^3.972.11", + "@aws-sdk/middleware-user-agent": "^3.972.32", + "@aws-sdk/region-config-resolver": "^3.972.12", + "@aws-sdk/signature-v4-multi-region": "^3.996.19", + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-endpoints": "^3.996.7", + "@aws-sdk/util-user-agent-browser": "^3.972.10", + "@aws-sdk/util-user-agent-node": "^3.973.18", + "@smithy/config-resolver": "^4.4.16", + "@smithy/core": "^3.23.15", + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/hash-node": "^4.2.14", + "@smithy/invalid-dependency": "^4.2.14", + "@smithy/middleware-content-length": "^4.2.14", + "@smithy/middleware-endpoint": "^4.4.30", + "@smithy/middleware-retry": "^4.5.3", + "@smithy/middleware-serde": "^4.2.18", + "@smithy/middleware-stack": "^4.2.14", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/node-http-handler": "^4.5.3", + "@smithy/protocol-http": "^5.3.14", + "@smithy/smithy-client": "^4.12.11", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-body-length-browser": "^4.2.2", + "@smithy/util-body-length-node": "^4.2.3", + "@smithy/util-defaults-mode-browser": "^4.3.47", + "@smithy/util-defaults-mode-node": "^4.2.52", + "@smithy/util-endpoints": "^3.4.1", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-retry": "^4.3.2", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/core": { + "version": "3.974.8", + "resolved": "https://registry.npmjs.org/@aws-sdk/core/-/core-3.974.8.tgz", + "integrity": "sha512-njR2qoG6ZuB0kvAS2FyICsFZJ6gmCcf2X/7JcD14sUvGDm26wiZ5BrA6LOiUxKFEF+IVe7kdroxyE00YlkiYsw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/xml-builder": "^3.972.22", + "@smithy/core": "^3.23.17", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/property-provider": "^4.2.14", + "@smithy/protocol-http": "^5.3.14", + "@smithy/signature-v4": "^5.3.14", + "@smithy/smithy-client": "^4.12.13", + "@smithy/types": "^4.14.1", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-retry": "^4.3.6", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/crc64-nvme": { + "version": "3.972.7", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-cognito-identity": { + "version": "3.972.25", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/nested-clients": "^3.997.0", + "@aws-sdk/types": "^3.973.8", + "@smithy/property-provider": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-env": { + "version": "3.972.34", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-env/-/credential-provider-env-3.972.34.tgz", + "integrity": "sha512-XT0jtf8Fw9JE6ppsQeoNnZRiG+jqRixMT1v1ZR17G60UvVdsQmTG8nbEyHuEPfMxDXEhfdARaM/XiEhca4lGHQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/types": "^3.973.8", + "@smithy/property-provider": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-http": { + "version": "3.972.36", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-http/-/credential-provider-http-3.972.36.tgz", + "integrity": "sha512-DPoGWfy7J7RKxvbf5kOKIGQkD2ek3dbKgzKIGrnLuvZBz5myU+Im/H6pmc14QcnFbqHMqxvtWSgRDSJW3qXLQg==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/types": "^3.973.8", + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/node-http-handler": "^4.6.1", + "@smithy/property-provider": "^4.2.14", + "@smithy/protocol-http": "^5.3.14", + "@smithy/smithy-client": "^4.12.13", + "@smithy/types": "^4.14.1", + "@smithy/util-stream": "^4.5.25", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-ini": { + "version": "3.972.38", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-ini/-/credential-provider-ini-3.972.38.tgz", + "integrity": "sha512-oDzUBu2MGJFgoar05sPMCwSrhw44ASyccrHzj66vO69OZqi7I6hZZxXfuPLC8OCzW7C+sU+bI73XHij41yekgQ==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/credential-provider-env": "^3.972.34", + "@aws-sdk/credential-provider-http": "^3.972.36", + "@aws-sdk/credential-provider-login": "^3.972.38", + "@aws-sdk/credential-provider-process": "^3.972.34", + "@aws-sdk/credential-provider-sso": "^3.972.38", + "@aws-sdk/credential-provider-web-identity": "^3.972.38", + "@aws-sdk/nested-clients": "^3.997.6", + "@aws-sdk/types": "^3.973.8", + "@smithy/credential-provider-imds": "^4.2.14", + "@smithy/property-provider": "^4.2.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-login": { + "version": "3.972.38", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-login/-/credential-provider-login-3.972.38.tgz", + "integrity": "sha512-g1NosS8qe4OF++G2UFCM5ovSkgipC7YYor5KCWatG0UoMSO5YFj9C8muePlyVmOBV/WTI16Jo3/s1NUo/o1Bww==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/nested-clients": "^3.997.6", + "@aws-sdk/types": "^3.973.8", + "@smithy/property-provider": "^4.2.14", + "@smithy/protocol-http": "^5.3.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-node": { + "version": "3.972.39", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-node/-/credential-provider-node-3.972.39.tgz", + "integrity": "sha512-HEswDQyxUtadoZ/bJsPPENHg7R0Lzym5LuMksJeHvqhCOpP+rtkDLKI4/ZChH4w3cf5kG8n6bZuI8PzajoiqMg==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/credential-provider-env": "^3.972.34", + "@aws-sdk/credential-provider-http": "^3.972.36", + "@aws-sdk/credential-provider-ini": "^3.972.38", + "@aws-sdk/credential-provider-process": "^3.972.34", + "@aws-sdk/credential-provider-sso": "^3.972.38", + "@aws-sdk/credential-provider-web-identity": "^3.972.38", + "@aws-sdk/types": "^3.973.8", + "@smithy/credential-provider-imds": "^4.2.14", + "@smithy/property-provider": "^4.2.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-process": { + "version": "3.972.34", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-process/-/credential-provider-process-3.972.34.tgz", + "integrity": "sha512-T3IFs4EVmVi1dVN5RciFnklCANSzvrQd/VuHY9ThHSQmYkTogjcGkoJEr+oNUPQZnso52183088NqysMPji1/Q==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/types": "^3.973.8", + "@smithy/property-provider": "^4.2.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-sso": { + "version": "3.972.38", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-sso/-/credential-provider-sso-3.972.38.tgz", + "integrity": "sha512-5ZxG+t0+3Q3QPh8KEjX6syskhgNf7I0MN7oGioTf6Lm1NTjfP7sIcYGNsthXC2qR8vcD3edNZwCr2ovfSSWuRA==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/nested-clients": "^3.997.6", + "@aws-sdk/token-providers": "3.1041.0", + "@aws-sdk/types": "^3.973.8", + "@smithy/property-provider": "^4.2.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-sso/node_modules/@aws-sdk/token-providers": { + "version": "3.1041.0", + "resolved": "https://registry.npmjs.org/@aws-sdk/token-providers/-/token-providers-3.1041.0.tgz", + "integrity": "sha512-Th7kPI6YPtvJUcdznooXJMy+9rQWjmEF81LxaJssngBzuysK4a/x+l8kjm1zb7nYsUPbndnBdUnwng/3PLvtGw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/nested-clients": "^3.997.6", + "@aws-sdk/types": "^3.973.8", + "@smithy/property-provider": "^4.2.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-provider-web-identity": { + "version": "3.972.38", + "resolved": "https://registry.npmjs.org/@aws-sdk/credential-provider-web-identity/-/credential-provider-web-identity-3.972.38.tgz", + "integrity": "sha512-lYHFF30DGI20jZcYX8cm6Ns0V7f1dDN6g/MBDLTyD/5iw+bXs3yBr2iAiHDkx4RFU5JgsnZvCHYKiRVPRdmOgw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/nested-clients": "^3.997.6", + "@aws-sdk/types": "^3.973.8", + "@smithy/property-provider": "^4.2.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/credential-providers": { + "version": "3.1033.0", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/client-cognito-identity": "3.1033.0", + "@aws-sdk/core": "^3.974.2", + "@aws-sdk/credential-provider-cognito-identity": "^3.972.25", + "@aws-sdk/credential-provider-env": "^3.972.28", + "@aws-sdk/credential-provider-http": "^3.972.30", + "@aws-sdk/credential-provider-ini": "^3.972.32", + "@aws-sdk/credential-provider-login": "^3.972.32", + "@aws-sdk/credential-provider-node": "^3.972.33", + "@aws-sdk/credential-provider-process": "^3.972.28", + "@aws-sdk/credential-provider-sso": "^3.972.32", + "@aws-sdk/credential-provider-web-identity": "^3.972.32", + "@aws-sdk/nested-clients": "^3.997.0", + "@aws-sdk/types": "^3.973.8", + "@smithy/config-resolver": "^4.4.16", + "@smithy/core": "^3.23.15", + "@smithy/credential-provider-imds": "^4.2.14", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/property-provider": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/eventstream-handler-node": { + "version": "3.972.14", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/eventstream-codec": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-bucket-endpoint": { + "version": "3.972.10", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-arn-parser": "^3.972.3", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "@smithy/util-config-provider": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-eventstream": { + "version": "3.972.10", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-expect-continue": { + "version": "3.972.10", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-flexible-checksums": { + "version": "3.974.10", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/crc32": "5.2.0", + "@aws-crypto/crc32c": "5.2.0", + "@aws-crypto/util": "5.2.0", + "@aws-sdk/core": "^3.974.2", + "@aws-sdk/crc64-nvme": "^3.972.7", + "@aws-sdk/types": "^3.973.8", + "@smithy/is-array-buffer": "^4.2.2", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-stream": "^4.5.23", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-host-header": { + "version": "3.972.10", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-location-constraint": { + "version": "3.972.10", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-logger": { + "version": "3.972.10", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-recursion-detection": { + "version": "3.972.11", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@aws/lambda-invoke-store": "^0.2.2", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-sdk-s3": { + "version": "3.972.37", + "resolved": "https://registry.npmjs.org/@aws-sdk/middleware-sdk-s3/-/middleware-sdk-s3-3.972.37.tgz", + "integrity": "sha512-Km7M+i8DrLArVzrid1gfxeGhYHBd3uxvE77g0s5a52zPSVosxzQBnJ0gwWb6NIp/DOk8gsBMhi7V+cpJG0ndTA==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-arn-parser": "^3.972.3", + "@smithy/core": "^3.23.17", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/protocol-http": "^5.3.14", + "@smithy/signature-v4": "^5.3.14", + "@smithy/smithy-client": "^4.12.13", + "@smithy/types": "^4.14.1", + "@smithy/util-config-provider": "^4.2.2", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-stream": "^4.5.25", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-ssec": { + "version": "3.972.10", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-user-agent": { + "version": "3.972.38", + "resolved": "https://registry.npmjs.org/@aws-sdk/middleware-user-agent/-/middleware-user-agent-3.972.38.tgz", + "integrity": "sha512-iz+B29TXcAZsJpwB+AwG/TTGA5l/VnmMZ2UxtiySOZjI6gCdmviXPwdgzcmuazMy16rXoPY4mYCGe7zdNKfx5A==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-endpoints": "^3.996.8", + "@smithy/core": "^3.23.17", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "@smithy/util-retry": "^4.3.6", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/middleware-websocket": { + "version": "3.972.16", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-format-url": "^3.972.10", + "@smithy/eventstream-codec": "^4.2.14", + "@smithy/eventstream-serde-browser": "^4.2.14", + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/protocol-http": "^5.3.14", + "@smithy/signature-v4": "^5.3.14", + "@smithy/types": "^4.14.1", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-hex-encoding": "^4.2.2", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">= 14.0.0" + } + }, + "node_modules/@aws-sdk/nested-clients": { + "version": "3.997.6", + "resolved": "https://registry.npmjs.org/@aws-sdk/nested-clients/-/nested-clients-3.997.6.tgz", + "integrity": "sha512-WBDnqatJl+kGObpfmfSxqnXeYTu3Me8wx8WCtvoxX3pfWrrTv8I4WTMSSs7PZqcRcVh8WeUKMgGFjMG+52SR1w==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "5.2.0", + "@aws-crypto/sha256-js": "5.2.0", + "@aws-sdk/core": "^3.974.8", + "@aws-sdk/middleware-host-header": "^3.972.10", + "@aws-sdk/middleware-logger": "^3.972.10", + "@aws-sdk/middleware-recursion-detection": "^3.972.11", + "@aws-sdk/middleware-user-agent": "^3.972.38", + "@aws-sdk/region-config-resolver": "^3.972.13", + "@aws-sdk/signature-v4-multi-region": "^3.996.25", + "@aws-sdk/types": "^3.973.8", + "@aws-sdk/util-endpoints": "^3.996.8", + "@aws-sdk/util-user-agent-browser": "^3.972.10", + "@aws-sdk/util-user-agent-node": "^3.973.24", + "@smithy/config-resolver": "^4.4.17", + "@smithy/core": "^3.23.17", + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/hash-node": "^4.2.14", + "@smithy/invalid-dependency": "^4.2.14", + "@smithy/middleware-content-length": "^4.2.14", + "@smithy/middleware-endpoint": "^4.4.32", + "@smithy/middleware-retry": "^4.5.7", + "@smithy/middleware-serde": "^4.2.20", + "@smithy/middleware-stack": "^4.2.14", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/node-http-handler": "^4.6.1", + "@smithy/protocol-http": "^5.3.14", + "@smithy/smithy-client": "^4.12.13", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-body-length-browser": "^4.2.2", + "@smithy/util-body-length-node": "^4.2.3", + "@smithy/util-defaults-mode-browser": "^4.3.49", + "@smithy/util-defaults-mode-node": "^4.2.54", + "@smithy/util-endpoints": "^3.4.2", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-retry": "^4.3.6", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/region-config-resolver": { + "version": "3.972.13", + "resolved": "https://registry.npmjs.org/@aws-sdk/region-config-resolver/-/region-config-resolver-3.972.13.tgz", + "integrity": "sha512-CvJ2ZIjK/jVD/lbOpowBVElJyC1YxLTIJ13yM0AEo0t2v7swOzGjSA6lJGH+DwZXQhcjUjoYwc8bVYCX5MDr1A==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/config-resolver": "^4.4.17", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/signature-v4-multi-region": { + "version": "3.996.25", + "resolved": "https://registry.npmjs.org/@aws-sdk/signature-v4-multi-region/-/signature-v4-multi-region-3.996.25.tgz", + "integrity": "sha512-+CMIt3e1VzlklAECmG+DtP1sV8iKq25FuA0OKpnJ4KA0kxUtd7CgClY7/RU6VzJBQwbN4EJ9Ue6plvqx1qGadw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/middleware-sdk-s3": "^3.972.37", + "@aws-sdk/types": "^3.973.8", + "@smithy/protocol-http": "^5.3.14", + "@smithy/signature-v4": "^5.3.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/token-providers": { + "version": "3.1033.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/core": "^3.974.2", + "@aws-sdk/nested-clients": "^3.997.0", + "@aws-sdk/types": "^3.973.8", + "@smithy/property-provider": "^4.2.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/types": { + "version": "3.973.8", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/util-arn-parser": { + "version": "3.972.3", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/util-endpoints": { + "version": "3.996.8", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-endpoints/-/util-endpoints-3.996.8.tgz", + "integrity": "sha512-oOZHcRDihk5iEe5V25NVWg45b3qEA8OpHWVdU/XQh8Zj4heVPAJqWvMphQnU7LkufmUo10EpvFPZuQMiFLJK3g==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-endpoints": "^3.4.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/util-format-url": { + "version": "3.972.10", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/querystring-builder": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/util-locate-window": { + "version": "3.965.5", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws-sdk/util-user-agent-browser": { + "version": "3.972.10", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/types": "^3.973.8", + "@smithy/types": "^4.14.1", + "bowser": "^2.11.0", + "tslib": "^2.6.2" + } + }, + "node_modules/@aws-sdk/util-user-agent-node": { + "version": "3.973.24", + "resolved": "https://registry.npmjs.org/@aws-sdk/util-user-agent-node/-/util-user-agent-node-3.973.24.tgz", + "integrity": "sha512-ZWwlkjcIp7cEL8ZfTpTAPNkwx25p7xol0xlKoWVVf22+nsjwmLcHYtTPjIV1cSpmB/b6DaK4cb1fSkvCXHgRdw==", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/middleware-user-agent": "^3.972.38", + "@aws-sdk/types": "^3.973.8", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/types": "^4.14.1", + "@smithy/util-config-provider": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "aws-crt": ">=1.0.0" + }, + "peerDependenciesMeta": { + "aws-crt": { + "optional": true + } + } + }, + "node_modules/@aws-sdk/xml-builder": { + "version": "3.972.22", + "resolved": "https://registry.npmjs.org/@aws-sdk/xml-builder/-/xml-builder-3.972.22.tgz", + "integrity": "sha512-PMYKKtJd70IsSG0yHrdAbxBr+ZWBKLvzFZfD3/urxgf6hXVMzuU5M+3MJ5G67RpOmLBu1fAUN65SbWuKUCOlAA==", + "license": "Apache-2.0", + "dependencies": { + "@nodable/entities": "2.1.0", + "@smithy/types": "^4.14.1", + "fast-xml-parser": "5.7.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@aws/bedrock-token-generator": { + "version": "1.1.0", + "resolved": "https://github.com/pgrayy/wasm-deps/releases/download/token-gen-v1.1.0/aws-bedrock-token-generator-1.1.0.tgz", + "integrity": "sha512-5A+Vkyj75mEsBRAQyhRchW3qmNXXG1yKffHwZB8UZ/KYKvK7Wa+/Vq31L8B+pkvTjnnAAW1GhPLtgs9ElgTU6g==", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/sha256-browser": "^5.2.0", + "@aws-sdk/credential-providers": "^3.525.0", + "@aws-sdk/util-format-url": ">=3.525.0", + "@smithy/config-resolver": "^4.1.4", + "@smithy/hash-node": ">=2.1.3", + "@smithy/invalid-dependency": "^4.0.4", + "@smithy/node-config-provider": "^4.1.3", + "@smithy/protocol-http": ">=3.2.1", + "@smithy/signature-v4": ">=2.1.3", + "@smithy/types": ">=2.11.0" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/@aws/lambda-invoke-store": { + "version": "0.2.4", + "license": "Apache-2.0", + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@babel/helper-string-parser": { + "version": "7.27.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/helper-validator-identifier": { + "version": "7.28.5", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/parser": { + "version": "7.29.2", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/types": "^7.29.0" + }, + "bin": { + "parser": "bin/babel-parser.js" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@babel/runtime": { + "version": "7.29.2", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/types": { + "version": "7.29.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.28.5" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@bcoe/v8-coverage": { + "version": "1.0.2", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, + "node_modules/@blazediff/core": { + "version": "1.9.1", + "dev": true, + "license": "MIT" + }, + "node_modules/@bytecodealliance/componentize-js": { + "version": "0.20.0", + "dev": true, + "workspaces": [ + "." + ], + "dependencies": { + "@bytecodealliance/jco": "^1.15.1", + "@bytecodealliance/weval": "^0.4.1", + "@bytecodealliance/wizer": "^10.0.0", + "es-module-lexer": "^1.6.0", + "oxc-parser": "^0.76.0" + }, + "bin": { + "componentize-js": "src/cli.js" + } + }, + "node_modules/@bytecodealliance/componentize-js-0-19-3": { + "name": "@bytecodealliance/componentize-js", + "version": "0.19.3", + "dev": true, + "workspaces": [ + "." + ], + "dependencies": { + "@bytecodealliance/jco": "^1.15.1", + "@bytecodealliance/wizer": "^10.0.0", + "es-module-lexer": "^1.6.0", + "oxc-parser": "^0.76.0" + }, + "bin": { + "componentize-js": "src/cli.js" + } + }, + "node_modules/@bytecodealliance/jco": { + "version": "1.18.1", + "dev": true, + "license": "(Apache-2.0 WITH LLVM-exception)", + "dependencies": { + "@bytecodealliance/componentize-js": "^0.20.0", + "@bytecodealliance/componentize-js-0-19-3": "npm:@bytecodealliance/componentize-js@^0.19.3", + "@bytecodealliance/preview2-shim": "^0.17.9", + "binaryen": "^123.0.0", + "commander": "^14", + "mkdirp": "^3", + "ora": "^8", + "terser": "^5" + }, + "bin": { + "jco": "src/jco.js" + } + }, + "node_modules/@bytecodealliance/preview2-shim": { + "version": "0.17.9", + "dev": true, + "license": "(Apache-2.0 WITH LLVM-exception)" + }, + "node_modules/@bytecodealliance/weval": { + "version": "0.4.1", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@napi-rs/lzma": "^1.1.2", + "decompress": "^4.2.1", + "decompress-tar": "^4.1.1", + "decompress-unzip": "^4.0.1" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/@bytecodealliance/wizer": { + "version": "10.0.0", + "dev": true, + "license": "Apache-2.0", + "bin": { + "wizer": "wizer.js" + }, + "engines": { + "node": ">=16" + }, + "optionalDependencies": { + "@bytecodealliance/wizer-darwin-arm64": "10.0.0", + "@bytecodealliance/wizer-darwin-x64": "10.0.0", + "@bytecodealliance/wizer-linux-arm64": "10.0.0", + "@bytecodealliance/wizer-linux-s390x": "10.0.0", + "@bytecodealliance/wizer-linux-x64": "10.0.0", + "@bytecodealliance/wizer-win32-x64": "10.0.0" + } + }, + "node_modules/@bytecodealliance/wizer-darwin-arm64": { + "version": "10.0.0", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "Apache-2.0", + "optional": true, + "os": [ + "darwin" + ], + "bin": { + "wizer-darwin-arm64": "wizer" + } + }, + "node_modules/@chaynabors/componentize-js": { + "version": "0.19.3", + "dev": true, + "workspaces": [ + "." + ], + "dependencies": { + "@bytecodealliance/jco": "^1.15.1", + "@bytecodealliance/wizer": "^10.0.0", + "es-module-lexer": "^1.6.0", + "oxc-parser": "^0.76.0" + }, + "bin": { + "componentize-js": "src/cli.js" + } + }, + "node_modules/@esbuild/aix-ppc64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz", + "integrity": "sha512-Hhmwd6CInZ3dwpuGTF8fJG6yoWmsToE+vYgD4nytZVxcu1ulHpUQRAB1UJ8+N1Am3Mz4+xOByoQoSZf4D+CpkA==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.25.12.tgz", + "integrity": "sha512-VJ+sKvNA/GE7Ccacc9Cha7bpS8nyzVv0jdVgwNDaR4gDMC/2TTRc33Ip8qrNYUcpkOHUT5OZ0bUcNNVZQ9RLlg==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.25.12.tgz", + "integrity": "sha512-6AAmLG7zwD1Z159jCKPvAxZd4y/VTO0VkprYy+3N2FtJ8+BQWFXU+OxARIwA46c5tdD9SsKGZ/1ocqBS/gAKHg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.25.12.tgz", + "integrity": "sha512-5jbb+2hhDHx5phYR2By8GTWEzn6I9UqR11Kwf22iKbNpYrsmRB18aX/9ivc5cabcUiAT/wM+YIZ6SG9QO6a8kg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-arm64": { + "version": "0.27.7", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.25.12.tgz", + "integrity": "sha512-HQ9ka4Kx21qHXwtlTUVbKJOAnmG1ipXhdWTmNXiPzPfWKpXqASVcWdnf2bnL73wgjNrFXAa3yYvBSd9pzfEIpA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.25.12.tgz", + "integrity": "sha512-gA0Bx759+7Jve03K1S0vkOu5Lg/85dou3EseOGUes8flVOGxbhDDh/iZaoek11Y8mtyKPGF3vP8XhnkDEAmzeg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.25.12.tgz", + "integrity": "sha512-TGbO26Yw2xsHzxtbVFGEXBFH0FRAP7gtcPE7P5yP7wGy7cXK2oO7RyOhL5NLiqTlBh47XhmIUXuGciXEqYFfBQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.25.12.tgz", + "integrity": "sha512-lPDGyC1JPDou8kGcywY0YILzWlhhnRjdof3UlcoqYmS9El818LLfJJc3PXXgZHrHCAKs/Z2SeZtDJr5MrkxtOw==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.25.12.tgz", + "integrity": "sha512-8bwX7a8FghIgrupcxb4aUmYDLp8pX06rGh5HqDT7bB+8Rdells6mHvrFHHW2JAOPZUbnjUpKTLg6ECyzvas2AQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.25.12.tgz", + "integrity": "sha512-0y9KrdVnbMM2/vG8KfU0byhUN+EFCny9+8g202gYqSSVMonbsCfLjUO+rCci7pM0WBEtz+oK/PIwHkzxkyharA==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.25.12.tgz", + "integrity": "sha512-h///Lr5a9rib/v1GGqXVGzjL4TMvVTv+s1DPoxQdz7l/AYv6LDSxdIwzxkrPW438oUXiDtwM10o9PmwS/6Z0Ng==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.25.12.tgz", + "integrity": "sha512-iyRrM1Pzy9GFMDLsXn1iHUm18nhKnNMWscjmp4+hpafcZjrr2WbT//d20xaGljXDBYHqRcl8HnxbX6uaA/eGVw==", + "cpu": [ + "mips64el" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.25.12.tgz", + "integrity": "sha512-9meM/lRXxMi5PSUqEXRCtVjEZBGwB7P/D4yT8UG/mwIdze2aV4Vo6U5gD3+RsoHXKkHCfSxZKzmDssVlRj1QQA==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.25.12.tgz", + "integrity": "sha512-Zr7KR4hgKUpWAwb1f3o5ygT04MzqVrGEGXGLnj15YQDJErYu/BGg+wmFlIDOdJp0PmB0lLvxFIOXZgFRrdjR0w==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.25.12.tgz", + "integrity": "sha512-MsKncOcgTNvdtiISc/jZs/Zf8d0cl/t3gYWX8J9ubBnVOwlk65UIEEvgBORTiljloIWnBzLs4qhzPkJcitIzIg==", + "cpu": [ + "s390x" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.25.12.tgz", + "integrity": "sha512-uqZMTLr/zR/ed4jIGnwSLkaHmPjOjJvnm6TVVitAa08SLS9Z0VM8wIRx7gWbJB5/J54YuIMInDquWyYvQLZkgw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-arm64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.25.12.tgz", + "integrity": "sha512-xXwcTq4GhRM7J9A8Gv5boanHhRa/Q9KLVmcyXHCTaM4wKfIpWkdXiMog/KsnxzJ0A1+nD+zoecuzqPmCRyBGjg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.25.12.tgz", + "integrity": "sha512-Ld5pTlzPy3YwGec4OuHh1aCVCRvOXdH8DgRjfDy/oumVovmuSzWfnSJg+VtakB9Cm0gxNO9BzWkj6mtO1FMXkQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-arm64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.25.12.tgz", + "integrity": "sha512-fF96T6KsBo/pkQI950FARU9apGNTSlZGsv1jZBAlcLL1MLjLNIWPBkj5NlSz8aAzYKg+eNqknrUJ24QBybeR5A==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.25.12.tgz", + "integrity": "sha512-MZyXUkZHjQxUvzK7rN8DJ3SRmrVrke8ZyRusHlP+kuwqTcfWLyqMOE3sScPPyeIXN/mDJIfGXvcMqCgYKekoQw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/openharmony-arm64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/openharmony-arm64/-/openharmony-arm64-0.25.12.tgz", + "integrity": "sha512-rm0YWsqUSRrjncSXGA7Zv78Nbnw4XL6/dzr20cyrQf7ZmRcsovpcRBdhD43Nuk3y7XIoW2OxMVvwuRvk9XdASg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.25.12.tgz", + "integrity": "sha512-3wGSCDyuTHQUzt0nV7bocDy72r2lI33QL3gkDNGkod22EsYl04sMf0qLb8luNKTOmgF/eDEDP5BFNwoBKH441w==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.25.12.tgz", + "integrity": "sha512-rMmLrur64A7+DKlnSuwqUdRKyd3UE7oPJZmnljqEptesKM8wx9J8gx5u0+9Pq0fQQW8vqeKebwNXdfOyP+8Bsg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.25.12.tgz", + "integrity": "sha512-HkqnmmBoCbCwxUKKNPBixiWDGCpQGVsrQfJoVGYLPT41XWF8lHuE5N6WhVia2n4o5QK5M4tYr21827fNhi4byQ==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.25.12.tgz", + "integrity": "sha512-alJC0uCZpTFrSL0CCDjcgleBXPnCrEAhTBILpeAp7M/OFgoqtAetfBzX0xM00MUsVVPpVjlPuMbREqnZCXaTnA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/@eslint-community/eslint-utils": { + "version": "4.9.1", + "dev": true, + "license": "MIT", + "dependencies": { + "eslint-visitor-keys": "^3.4.3" + }, + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + }, + "peerDependencies": { + "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" + } + }, + "node_modules/@eslint-community/regexpp": { + "version": "4.12.2", + "dev": true, + "license": "MIT", + "engines": { + "node": "^12.0.0 || ^14.0.0 || >=16.0.0" + } + }, + "node_modules/@eslint/config-array": { + "version": "0.23.5", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/object-schema": "^3.0.5", + "debug": "^4.3.1", + "minimatch": "^10.2.4" + }, + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + } + }, + "node_modules/@eslint/config-helpers": { + "version": "0.5.5", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^1.2.1" + }, + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + } + }, + "node_modules/@eslint/core": { + "version": "1.2.1", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@types/json-schema": "^7.0.15" + }, + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + } + }, + "node_modules/@eslint/js": { + "version": "9.39.4", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "url": "https://eslint.org/donate" + } + }, + "node_modules/@eslint/object-schema": { + "version": "3.0.5", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + } + }, + "node_modules/@eslint/plugin-kit": { + "version": "0.7.1", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^1.2.1", + "levn": "^0.4.1" + }, + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + } + }, + "node_modules/@google/genai": { + "version": "1.50.1", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "google-auth-library": "^10.3.0", + "p-retry": "^4.6.2", + "protobufjs": "^7.5.4", + "ws": "^8.18.0" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "@modelcontextprotocol/sdk": "^1.25.2" + }, + "peerDependenciesMeta": { + "@modelcontextprotocol/sdk": { + "optional": true + } + } + }, + "node_modules/@hono/node-server": { + "version": "1.19.14", + "license": "MIT", + "peer": true, + "engines": { + "node": ">=18.14.1" + }, + "peerDependencies": { + "hono": "^4" + } + }, + "node_modules/@humanfs/core": { + "version": "0.19.2", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@humanfs/types": "^0.15.0" + }, + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanfs/node": { + "version": "0.16.8", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@humanfs/core": "^0.19.2", + "@humanfs/types": "^0.15.0", + "@humanwhocodes/retry": "^0.4.0" + }, + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanfs/types": { + "version": "0.15.0", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18.0" + } + }, + "node_modules/@humanwhocodes/module-importer": { + "version": "1.0.1", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.22" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@humanwhocodes/retry": { + "version": "0.4.3", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18.18" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/nzakas" + } + }, + "node_modules/@jridgewell/gen-mapping": { + "version": "0.3.13", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.0", + "@jridgewell/trace-mapping": "^0.3.24" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/source-map": { + "version": "0.3.11", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.25" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.5", + "dev": true, + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.31", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/@microsoft/tsdoc": { + "version": "0.16.0", + "dev": true, + "license": "MIT" + }, + "node_modules/@microsoft/tsdoc-config": { + "version": "0.18.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@microsoft/tsdoc": "0.16.0", + "ajv": "~8.18.0", + "jju": "~1.4.0", + "resolve": "~1.22.2" + } + }, + "node_modules/@modelcontextprotocol/sdk": { + "version": "1.29.0", + "license": "MIT", + "peer": true, + "dependencies": { + "@hono/node-server": "^1.19.9", + "ajv": "^8.17.1", + "ajv-formats": "^3.0.1", + "content-type": "^1.0.5", + "cors": "^2.8.5", + "cross-spawn": "^7.0.5", + "eventsource": "^3.0.2", + "eventsource-parser": "^3.0.0", + "express": "^5.2.1", + "express-rate-limit": "^8.2.1", + "hono": "^4.11.4", + "jose": "^6.1.3", + "json-schema-typed": "^8.0.2", + "pkce-challenge": "^5.0.0", + "raw-body": "^3.0.0", + "zod": "^3.25 || ^4.0", + "zod-to-json-schema": "^3.25.1" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@cfworker/json-schema": "^4.1.1", + "zod": "^3.25 || ^4.0" + }, + "peerDependenciesMeta": { + "@cfworker/json-schema": { + "optional": true + }, + "zod": { + "optional": false + } + } + }, + "node_modules/@napi-rs/lzma": { + "version": "1.4.5", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + }, + "optionalDependencies": { + "@napi-rs/lzma-android-arm-eabi": "1.4.5", + "@napi-rs/lzma-android-arm64": "1.4.5", + "@napi-rs/lzma-darwin-arm64": "1.4.5", + "@napi-rs/lzma-darwin-x64": "1.4.5", + "@napi-rs/lzma-freebsd-x64": "1.4.5", + "@napi-rs/lzma-linux-arm-gnueabihf": "1.4.5", + "@napi-rs/lzma-linux-arm64-gnu": "1.4.5", + "@napi-rs/lzma-linux-arm64-musl": "1.4.5", + "@napi-rs/lzma-linux-ppc64-gnu": "1.4.5", + "@napi-rs/lzma-linux-riscv64-gnu": "1.4.5", + "@napi-rs/lzma-linux-s390x-gnu": "1.4.5", + "@napi-rs/lzma-linux-x64-gnu": "1.4.5", + "@napi-rs/lzma-linux-x64-musl": "1.4.5", + "@napi-rs/lzma-wasm32-wasi": "1.4.5", + "@napi-rs/lzma-win32-arm64-msvc": "1.4.5", + "@napi-rs/lzma-win32-ia32-msvc": "1.4.5", + "@napi-rs/lzma-win32-x64-msvc": "1.4.5" + } + }, + "node_modules/@napi-rs/lzma-darwin-arm64": { + "version": "1.4.5", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@nodable/entities": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/@nodable/entities/-/entities-2.1.0.tgz", + "integrity": "sha512-nyT7T3nbMyBI/lvr6L5TyWbFJAI9FTgVRakNoBqCD+PmID8DzFrrNdLLtHMwMszOtqZa8PAOV24ZqDnQrhQINA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/nodable" + } + ], + "license": "MIT" + }, + "node_modules/@opentelemetry/api": { + "version": "1.9.1", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/@opentelemetry/api-logs": { + "version": "0.214.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/api": "^1.3.0" + }, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/@opentelemetry/context-async-hooks": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.0.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/core": { + "version": "2.6.1", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.0.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/exporter-metrics-otlp-http": { + "version": "0.214.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.6.1", + "@opentelemetry/otlp-exporter-base": "0.214.0", + "@opentelemetry/otlp-transformer": "0.214.0", + "@opentelemetry/resources": "2.6.1", + "@opentelemetry/sdk-metrics": "2.6.1" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": "^1.3.0" + } + }, + "node_modules/@opentelemetry/exporter-metrics-otlp-http/node_modules/@opentelemetry/sdk-metrics": { + "version": "2.6.1", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.6.1", + "@opentelemetry/resources": "2.6.1" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.9.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/exporter-trace-otlp-http": { + "version": "0.214.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.6.1", + "@opentelemetry/otlp-exporter-base": "0.214.0", + "@opentelemetry/otlp-transformer": "0.214.0", + "@opentelemetry/resources": "2.6.1", + "@opentelemetry/sdk-trace-base": "2.6.1" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": "^1.3.0" + } + }, + "node_modules/@opentelemetry/otlp-exporter-base": { + "version": "0.214.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.6.1", + "@opentelemetry/otlp-transformer": "0.214.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": "^1.3.0" + } + }, + "node_modules/@opentelemetry/otlp-transformer": { + "version": "0.214.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/api-logs": "0.214.0", + "@opentelemetry/core": "2.6.1", + "@opentelemetry/resources": "2.6.1", + "@opentelemetry/sdk-logs": "0.214.0", + "@opentelemetry/sdk-metrics": "2.6.1", + "@opentelemetry/sdk-trace-base": "2.6.1", + "protobufjs": "^7.0.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": "^1.3.0" + } + }, + "node_modules/@opentelemetry/otlp-transformer/node_modules/@opentelemetry/sdk-metrics": { + "version": "2.6.1", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.6.1", + "@opentelemetry/resources": "2.6.1" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.9.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/resources": { + "version": "2.6.1", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.6.1", + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.3.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/sdk-logs": { + "version": "0.214.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/api-logs": "0.214.0", + "@opentelemetry/core": "2.6.1", + "@opentelemetry/resources": "2.6.1", + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.4.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/sdk-metrics": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.7.0", + "@opentelemetry/resources": "2.7.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.9.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/sdk-metrics/node_modules/@opentelemetry/core": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.0.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/sdk-metrics/node_modules/@opentelemetry/resources": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.7.0", + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.3.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/sdk-trace-base": { + "version": "2.6.1", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.6.1", + "@opentelemetry/resources": "2.6.1", + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.3.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/sdk-trace-node": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/context-async-hooks": "2.7.0", + "@opentelemetry/core": "2.7.0", + "@opentelemetry/sdk-trace-base": "2.7.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.0.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/sdk-trace-node/node_modules/@opentelemetry/core": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.0.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/sdk-trace-node/node_modules/@opentelemetry/resources": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.7.0", + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.3.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/sdk-trace-node/node_modules/@opentelemetry/sdk-trace-base": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.7.0", + "@opentelemetry/resources": "2.7.0", + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.3.0 <1.10.0" + } + }, + "node_modules/@opentelemetry/semantic-conventions": { + "version": "1.40.0", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=14" + } + }, + "node_modules/@oxc-parser/binding-darwin-arm64": { + "version": "0.76.0", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@oxc-project/types": { + "version": "0.76.0", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/Boshen" + } + }, + "node_modules/@polka/url": { + "version": "1.0.0-next.29", + "dev": true, + "license": "MIT" + }, + "node_modules/@protobufjs/aspromise": { + "version": "1.1.2", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/base64": { + "version": "1.1.2", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/codegen": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@protobufjs/codegen/-/codegen-2.0.5.tgz", + "integrity": "sha512-zgXFLzW3Ap33e6d0Wlj4MGIm6Ce8O89n/apUaGNB/jx+hw+ruWEp7EwGUshdLKVRCxZW12fp9r40E1mQrf/34g==", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/eventemitter": { + "version": "1.1.0", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/fetch": { + "version": "1.1.0", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "@protobufjs/aspromise": "^1.1.1", + "@protobufjs/inquire": "^1.1.0" + } + }, + "node_modules/@protobufjs/float": { + "version": "1.0.2", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/inquire": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@protobufjs/inquire/-/inquire-1.1.1.tgz", + "integrity": "sha512-mnzgDV26ueAvk7rsbt9L7bE0SuAoqyuys/sMMrmVcN5x9VsxpcG3rqAUSgDyLp0UZlmNfIbQ4fHfCtreVBk8Ew==", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/path": { + "version": "1.1.2", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/pool": { + "version": "1.1.0", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/utf8": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@protobufjs/utf8/-/utf8-1.1.1.tgz", + "integrity": "sha512-oOAWABowe8EAbMyWKM0tYDKi8Yaox52D+HWZhAIJqQXbqe0xI/GV7FhLWqlEKreMkfDjshR5FKgi3mnle0h6Eg==", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/@rollup/rollup-android-arm-eabi": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.60.2.tgz", + "integrity": "sha512-dnlp69efPPg6Uaw2dVqzWRfAWRnYVb1XJ8CyyhIbZeaq4CA5/mLeZ1IEt9QqQxmbdvagjLIm2ZL8BxXv5lH4Yw==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-android-arm64": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.60.2.tgz", + "integrity": "sha512-OqZTwDRDchGRHHm/hwLOL7uVPB9aUvI0am/eQuWMNyFHf5PSEQmyEeYYheA0EPPKUO/l0uigCp+iaTjoLjVoHg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-darwin-arm64": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.60.2.tgz", + "integrity": "sha512-UwRE7CGpvSVEQS8gUMBe1uADWjNnVgP3Iusyda1nSRwNDCsRjnGc7w6El6WLQsXmZTbLZx9cecegumcitNfpmA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-darwin-x64": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.60.2.tgz", + "integrity": "sha512-gjEtURKLCC5VXm1I+2i1u9OhxFsKAQJKTVB8WvDAHF+oZlq0GTVFOlTlO1q3AlCTE/DF32c16ESvfgqR7343/g==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-freebsd-arm64": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.60.2.tgz", + "integrity": "sha512-Bcl6CYDeAgE70cqZaMojOi/eK63h5Me97ZqAQoh77VPjMysA/4ORQBRGo3rRy45x4MzVlU9uZxs8Uwy7ZaKnBw==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-freebsd-x64": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.60.2.tgz", + "integrity": "sha512-LU+TPda3mAE2QB0/Hp5VyeKJivpC6+tlOXd1VMoXV/YFMvk/MNk5iXeBfB4MQGRWyOYVJ01625vjkr0Az98OJQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-linux-arm-gnueabihf": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.60.2.tgz", + "integrity": "sha512-2QxQrM+KQ7DAW4o22j+XZ6RKdxjLD7BOWTP0Bv0tmjdyhXSsr2Ul1oJDQqh9Zf5qOwTuTc7Ek83mOFaKnodPjg==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm-musleabihf": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.60.2.tgz", + "integrity": "sha512-TbziEu2DVsTEOPif2mKWkMeDMLoYjx95oESa9fkQQK7r/Orta0gnkcDpzwufEcAO2BLBsD7mZkXGFqEdMRRwfw==", + "cpu": [ + "arm" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-gnu": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.60.2.tgz", + "integrity": "sha512-bO/rVDiDUuM2YfuCUwZ1t1cP+/yqjqz+Xf2VtkdppefuOFS2OSeAfgafaHNkFn0t02hEyXngZkxtGqXcXwO8Rg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-musl": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.60.2.tgz", + "integrity": "sha512-hr26p7e93Rl0Za+JwW7EAnwAvKkehh12BU1Llm9Ykiibg4uIr2rbpxG9WCf56GuvidlTG9KiiQT/TXT1yAWxTA==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-loong64-gnu": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.60.2.tgz", + "integrity": "sha512-pOjB/uSIyDt+ow3k/RcLvUAOGpysT2phDn7TTUB3n75SlIgZzM6NKAqlErPhoFU+npgY3/n+2HYIQVbF70P9/A==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-loong64-musl": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.60.2.tgz", + "integrity": "sha512-2/w+q8jszv9Ww1c+6uJT3OwqhdmGP2/4T17cu8WuwyUuuaCDDJ2ojdyYwZzCxx0GcsZBhzi3HmH+J5pZNXnd+Q==", + "cpu": [ + "loong64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-ppc64-gnu": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.60.2.tgz", + "integrity": "sha512-11+aL5vKheYgczxtPVVRhdptAM2H7fcDR5Gw4/bTcteuZBlH4oP9f5s9zYO9aGZvoGeBpqXI/9TZZihZ609wKw==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-ppc64-musl": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.60.2.tgz", + "integrity": "sha512-i16fokAGK46IVZuV8LIIwMdtqhin9hfYkCh8pf8iC3QU3LpwL+1FSFGej+O7l3E/AoknL6Dclh2oTdnRMpTzFQ==", + "cpu": [ + "ppc64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-gnu": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.60.2.tgz", + "integrity": "sha512-49FkKS6RGQoriDSK/6E2GkAsAuU5kETFCh7pG4yD/ylj9rKhTmO3elsnmBvRD4PgJPds5W2PkhC82aVwmUcJ7A==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-musl": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.60.2.tgz", + "integrity": "sha512-mjYNkHPfGpUR00DuM1ZZIgs64Hpf4bWcz9Z41+4Q+pgDx73UwWdAYyf6EG/lRFldmdHHzgrYyge5akFUW0D3mQ==", + "cpu": [ + "riscv64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-s390x-gnu": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.60.2.tgz", + "integrity": "sha512-ALyvJz965BQk8E9Al/JDKKDLH2kfKFLTGMlgkAbbYtZuJt9LU8DW3ZoDMCtQpXAltZxwBHevXz5u+gf0yA0YoA==", + "cpu": [ + "s390x" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-gnu": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.60.2.tgz", + "integrity": "sha512-UQjrkIdWrKI626Du8lCQ6MJp/6V1LAo2bOK9OTu4mSn8GGXIkPXk/Vsp4bLHCd9Z9Iz2OTEaokUE90VweJgIYQ==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-musl": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.60.2.tgz", + "integrity": "sha512-bTsRGj6VlSdn/XD4CGyzMnzaBs9bsRxy79eTqTCBsA8TMIEky7qg48aPkvJvFe1HyzQ5oMZdg7AnVlWQSKLTnw==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-openbsd-x64": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.60.2.tgz", + "integrity": "sha512-6d4Z3534xitaA1FcMWP7mQPq5zGwBmGbhphh2DwaA1aNIXUu3KTOfwrWpbwI4/Gr0uANo7NTtaykFyO2hPuFLg==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ] + }, + "node_modules/@rollup/rollup-openharmony-arm64": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.60.2.tgz", + "integrity": "sha512-NetAg5iO2uN7eB8zE5qrZ3CSil+7IJt4WDFLcC75Ymywq1VZVD6qJ6EvNLjZ3rEm6gB7XW5JdT60c6MN35Z85Q==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ] + }, + "node_modules/@rollup/rollup-win32-arm64-msvc": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.60.2.tgz", + "integrity": "sha512-NCYhOotpgWZ5kdxCZsv6Iudx0wX8980Q/oW4pNFNihpBKsDbEA1zpkfxJGC0yugsUuyDZ7gL37dbzwhR0VI7pQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-ia32-msvc": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.60.2.tgz", + "integrity": "sha512-RXsaOqXxfoUBQoOgvmmijVxJnW2IGB0eoMO7F8FAjaj0UTywUO/luSqimWBJn04WNgUkeNhh7fs7pESXajWmkg==", + "cpu": [ + "ia32" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-x64-gnu": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.60.2.tgz", + "integrity": "sha512-qdAzEULD+/hzObedtmV6iBpdL5TIbKVztGiK7O3/KYSf+HIzU257+MX1EXJcyIiDbMAqmbwaufcYPvyRryeZtA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-x64-msvc": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.60.2.tgz", + "integrity": "sha512-Nd/SgG27WoA9e+/TdK74KnHz852TLa94ovOYySo/yMPuTmpckK/jIF2jSwS3g7ELSKXK13/cVdmg1Z/DaCWKxA==", + "cpu": [ + "x64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@smithy/chunked-blob-reader": { + "version": "5.2.2", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/chunked-blob-reader-native": { + "version": "4.2.3", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-base64": "^4.3.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/config-resolver": { + "version": "4.4.17", + "license": "Apache-2.0", + "dependencies": { + "@smithy/node-config-provider": "^4.3.14", + "@smithy/types": "^4.14.1", + "@smithy/util-config-provider": "^4.2.2", + "@smithy/util-endpoints": "^3.4.2", + "@smithy/util-middleware": "^4.2.14", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/core": { + "version": "3.23.17", + "resolved": "https://registry.npmjs.org/@smithy/core/-/core-3.23.17.tgz", + "integrity": "sha512-x7BlLbUFL8NWCGjMF9C+1N5cVCxcPa7g6Tv9B4A2luWx3be3oU8hQ96wIwxe/s7OhIzvoJH73HAUSg5JXVlEtQ==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-body-length-browser": "^4.2.2", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-stream": "^4.5.25", + "@smithy/util-utf8": "^4.2.2", + "@smithy/uuid": "^1.1.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/credential-provider-imds": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/node-config-provider": "^4.3.14", + "@smithy/property-provider": "^4.2.14", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-codec": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@aws-crypto/crc32": "5.2.0", + "@smithy/types": "^4.14.1", + "@smithy/util-hex-encoding": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-serde-browser": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/eventstream-serde-universal": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-serde-config-resolver": { + "version": "4.3.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-serde-node": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/eventstream-serde-universal": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/eventstream-serde-universal": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/eventstream-codec": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/fetch-http-handler": { + "version": "5.3.17", + "license": "Apache-2.0", + "dependencies": { + "@smithy/protocol-http": "^5.3.14", + "@smithy/querystring-builder": "^4.2.14", + "@smithy/types": "^4.14.1", + "@smithy/util-base64": "^4.3.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/hash-blob-browser": { + "version": "4.2.15", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@smithy/chunked-blob-reader": "^5.2.2", + "@smithy/chunked-blob-reader-native": "^4.2.3", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/hash-node": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "@smithy/util-buffer-from": "^4.2.2", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/hash-stream-node": { + "version": "4.2.14", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/invalid-dependency": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/is-array-buffer": { + "version": "4.2.2", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/md5-js": { + "version": "4.2.14", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-content-length": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-endpoint": { + "version": "4.4.32", + "resolved": "https://registry.npmjs.org/@smithy/middleware-endpoint/-/middleware-endpoint-4.4.32.tgz", + "integrity": "sha512-ZZkgyjnJppiZbIm6Qbx92pbXYi1uzenIvGhBSCDlc7NwuAkiqSgS75j1czAD25ZLs2FjMjYy1q7gyRVWG6JA0Q==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/core": "^3.23.17", + "@smithy/middleware-serde": "^4.2.20", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "@smithy/url-parser": "^4.2.14", + "@smithy/util-middleware": "^4.2.14", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-retry": { + "version": "4.5.7", + "resolved": "https://registry.npmjs.org/@smithy/middleware-retry/-/middleware-retry-4.5.7.tgz", + "integrity": "sha512-bRt6ZImqVSeTk39Nm81K20ObIiAZ3WefY7G6+iz/0tZjs4dgRRjvRX2sgsH+zi6iDCRR/aQvQofLKxxz4rPBZg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/core": "^3.23.17", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/protocol-http": "^5.3.14", + "@smithy/service-error-classification": "^4.3.1", + "@smithy/smithy-client": "^4.12.13", + "@smithy/types": "^4.14.1", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-retry": "^4.3.6", + "@smithy/uuid": "^1.1.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-serde": { + "version": "4.2.20", + "resolved": "https://registry.npmjs.org/@smithy/middleware-serde/-/middleware-serde-4.2.20.tgz", + "integrity": "sha512-Lx9JMO9vArPtiChE3wbEZ5akMIDQpWQtlu90lhACQmNOXcGXRbaDywMHDzuDZ2OkZzP+9wQfZi3YJT9F67zTQQ==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/core": "^3.23.17", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/middleware-stack": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/node-config-provider": { + "version": "4.3.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/property-provider": "^4.2.14", + "@smithy/shared-ini-file-loader": "^4.4.9", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/node-http-handler": { + "version": "4.6.1", + "resolved": "https://registry.npmjs.org/@smithy/node-http-handler/-/node-http-handler-4.6.1.tgz", + "integrity": "sha512-iB+orM4x3xrr57X3YaXazfKnntl0LHlZB1kcXSGzMV1Tt0+YwEjGlbjk/44qEGtBzXAz6yFDzkYTKSV6Pj2HUg==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/protocol-http": "^5.3.14", + "@smithy/querystring-builder": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/property-provider": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/protocol-http": { + "version": "5.3.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/querystring-builder": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "@smithy/util-uri-escape": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/querystring-parser": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/service-error-classification": { + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/@smithy/service-error-classification/-/service-error-classification-4.3.1.tgz", + "integrity": "sha512-aUQuDGh760ts/8MU+APjIZhlLPKhIIfqyzZaJikLEIMrdxFvxuLYD0WxWzaYWpmLbQlXDe9p7EWM3HsBe0K6Gw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/shared-ini-file-loader": { + "version": "4.4.9", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/signature-v4": { + "version": "5.3.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/is-array-buffer": "^4.2.2", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "@smithy/util-hex-encoding": "^4.2.2", + "@smithy/util-middleware": "^4.2.14", + "@smithy/util-uri-escape": "^4.2.2", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/smithy-client": { + "version": "4.12.13", + "resolved": "https://registry.npmjs.org/@smithy/smithy-client/-/smithy-client-4.12.13.tgz", + "integrity": "sha512-y/Pcj1V9+qG98gyu1gvftHB7rDpdh+7kIBIggs55yGm3JdtBV8GT8IFF3a1qxZ79QnaJHX9GXzvBG6tAd+czJA==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/core": "^3.23.17", + "@smithy/middleware-endpoint": "^4.4.32", + "@smithy/middleware-stack": "^4.2.14", + "@smithy/protocol-http": "^5.3.14", + "@smithy/types": "^4.14.1", + "@smithy/util-stream": "^4.5.25", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/types": { + "version": "4.14.1", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/url-parser": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/querystring-parser": "^4.2.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-base64": { + "version": "4.3.2", + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^4.2.2", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-body-length-browser": { + "version": "4.2.2", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-body-length-node": { + "version": "4.2.3", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-buffer-from": { + "version": "4.2.2", + "license": "Apache-2.0", + "dependencies": { + "@smithy/is-array-buffer": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-config-provider": { + "version": "4.2.2", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-defaults-mode-browser": { + "version": "4.3.49", + "resolved": "https://registry.npmjs.org/@smithy/util-defaults-mode-browser/-/util-defaults-mode-browser-4.3.49.tgz", + "integrity": "sha512-a5bNrdiONYB/qE2BuKegvUMd/+ZDwdg4vsNuuSzYE8qs2EYAdK9CynL+Rzn29PbPiUqoz/cbpRbcLzD5lEevHw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/property-provider": "^4.2.14", + "@smithy/smithy-client": "^4.12.13", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-defaults-mode-node": { + "version": "4.2.54", + "resolved": "https://registry.npmjs.org/@smithy/util-defaults-mode-node/-/util-defaults-mode-node-4.2.54.tgz", + "integrity": "sha512-g1cvrJvOnzeJgEdf7AE4luI7gp6L8weE0y9a9wQUSGtjb8QRHDbCJYuE4Sy0SD9N8RrnNPFsPltAz/OSoBR9Zw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/config-resolver": "^4.4.17", + "@smithy/credential-provider-imds": "^4.2.14", + "@smithy/node-config-provider": "^4.3.14", + "@smithy/property-provider": "^4.2.14", + "@smithy/smithy-client": "^4.12.13", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-endpoints": { + "version": "3.4.2", + "license": "Apache-2.0", + "dependencies": { + "@smithy/node-config-provider": "^4.3.14", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-hex-encoding": { + "version": "4.2.2", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-middleware": { + "version": "4.2.14", + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-retry": { + "version": "4.3.8", + "resolved": "https://registry.npmjs.org/@smithy/util-retry/-/util-retry-4.3.8.tgz", + "integrity": "sha512-LUIxbTBi+OpvXpg91poGA6BdyoleMDLnfXjVDqyi2RvZmTveY5loE/FgYUBCR5LU2BThW2SoZRh8dTIIy38IPw==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/service-error-classification": "^4.3.1", + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-stream": { + "version": "4.5.25", + "resolved": "https://registry.npmjs.org/@smithy/util-stream/-/util-stream-4.5.25.tgz", + "integrity": "sha512-/PFpG4k8Ze8Ei+mMKj3oiPICYekthuzePZMgZbCqMiXIHHf4n2aZ4Ps0aSRShycFTGuj/J6XldmC0x0DwednIA==", + "license": "Apache-2.0", + "dependencies": { + "@smithy/fetch-http-handler": "^5.3.17", + "@smithy/node-http-handler": "^4.6.1", + "@smithy/types": "^4.14.1", + "@smithy/util-base64": "^4.3.2", + "@smithy/util-buffer-from": "^4.2.2", + "@smithy/util-hex-encoding": "^4.2.2", + "@smithy/util-utf8": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-uri-escape": { + "version": "4.2.2", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-utf8": { + "version": "4.2.2", + "license": "Apache-2.0", + "dependencies": { + "@smithy/util-buffer-from": "^4.2.2", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/util-waiter": { + "version": "4.2.16", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@smithy/types": "^4.14.1", + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@smithy/uuid": { + "version": "1.1.2", + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.6.2" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/@standard-schema/spec": { + "version": "1.1.0", + "dev": true, + "license": "MIT" + }, + "node_modules/@strands-agents/sdk": { + "resolved": "strands-ts", + "link": true + }, + "node_modules/@strands-agents/strandly": { + "resolved": "strandly", + "link": true + }, + "node_modules/@strands-agents/wasm": { + "resolved": "strands-wasm", + "link": true + }, + "node_modules/@types/body-parser": { + "version": "1.19.6", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/connect": "*", + "@types/node": "*" + } + }, + "node_modules/@types/chai": { + "version": "5.2.3", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/deep-eql": "*", + "assertion-error": "^2.0.1" + } + }, + "node_modules/@types/connect": { + "version": "3.4.38", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/deep-eql": { + "version": "4.0.2", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/esrecurse": { + "version": "4.3.1", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/estree": { + "version": "1.0.8", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/express": { + "version": "5.0.6", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/body-parser": "*", + "@types/express-serve-static-core": "^5.0.0", + "@types/serve-static": "^2" + } + }, + "node_modules/@types/express-serve-static-core": { + "version": "5.1.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*", + "@types/qs": "*", + "@types/range-parser": "*", + "@types/send": "*" + } + }, + "node_modules/@types/http-errors": { + "version": "2.0.5", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/json-schema": { + "version": "7.0.15", + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "22.19.17", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~6.21.0" + } + }, + "node_modules/@types/qs": { + "version": "6.15.0", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/range-parser": { + "version": "1.2.7", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/retry": { + "version": "0.12.0", + "dev": true, + "license": "MIT" + }, + "node_modules/@types/send": { + "version": "1.2.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/@types/serve-static": { + "version": "2.2.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/http-errors": "*", + "@types/node": "*" + } + }, + "node_modules/@types/uuid": { + "version": "11.0.0", + "deprecated": "This is a stub types definition. uuid provides its own type definitions, so you do not need this installed.", + "dev": true, + "license": "MIT", + "dependencies": { + "uuid": "*" + } + }, + "node_modules/@typescript-eslint/eslint-plugin": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/regexpp": "^4.12.2", + "@typescript-eslint/scope-manager": "8.59.0", + "@typescript-eslint/type-utils": "8.59.0", + "@typescript-eslint/utils": "8.59.0", + "@typescript-eslint/visitor-keys": "8.59.0", + "ignore": "^7.0.5", + "natural-compare": "^1.4.0", + "ts-api-utils": "^2.5.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "@typescript-eslint/parser": "^8.59.0", + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.1.0" + } + }, + "node_modules/@typescript-eslint/parser": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/scope-manager": "8.59.0", + "@typescript-eslint/types": "8.59.0", + "@typescript-eslint/typescript-estree": "8.59.0", + "@typescript-eslint/visitor-keys": "8.59.0", + "debug": "^4.4.3" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.1.0" + } + }, + "node_modules/@typescript-eslint/project-service": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/tsconfig-utils": "^8.59.0", + "@typescript-eslint/types": "^8.59.0", + "debug": "^4.4.3" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.1.0" + } + }, + "node_modules/@typescript-eslint/scope-manager": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.59.0", + "@typescript-eslint/visitor-keys": "8.59.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/tsconfig-utils": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.1.0" + } + }, + "node_modules/@typescript-eslint/type-utils": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.59.0", + "@typescript-eslint/typescript-estree": "8.59.0", + "@typescript-eslint/utils": "8.59.0", + "debug": "^4.4.3", + "ts-api-utils": "^2.5.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.1.0" + } + }, + "node_modules/@typescript-eslint/types": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/typescript-estree": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/project-service": "8.59.0", + "@typescript-eslint/tsconfig-utils": "8.59.0", + "@typescript-eslint/types": "8.59.0", + "@typescript-eslint/visitor-keys": "8.59.0", + "debug": "^4.4.3", + "minimatch": "^10.2.2", + "semver": "^7.7.3", + "tinyglobby": "^0.2.15", + "ts-api-utils": "^2.5.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.1.0" + } + }, + "node_modules/@typescript-eslint/utils": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.9.1", + "@typescript-eslint/scope-manager": "8.59.0", + "@typescript-eslint/types": "8.59.0", + "@typescript-eslint/typescript-estree": "8.59.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.1.0" + } + }, + "node_modules/@typescript-eslint/visitor-keys": { + "version": "8.59.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.59.0", + "eslint-visitor-keys": "^5.0.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/visitor-keys/node_modules/eslint-visitor-keys": { + "version": "5.0.1", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/@vitest/browser": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "@blazediff/core": "1.9.1", + "@vitest/mocker": "4.1.4", + "@vitest/utils": "4.1.4", + "magic-string": "^0.30.21", + "pngjs": "^7.0.0", + "sirv": "^3.0.2", + "tinyrainbow": "^3.1.0", + "ws": "^8.19.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "vitest": "4.1.4" + } + }, + "node_modules/@vitest/browser-playwright": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/browser": "4.1.4", + "@vitest/mocker": "4.1.4", + "tinyrainbow": "^3.1.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "playwright": "*", + "vitest": "4.1.4" + }, + "peerDependenciesMeta": { + "playwright": { + "optional": false + } + } + }, + "node_modules/@vitest/coverage-v8": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "@bcoe/v8-coverage": "^1.0.2", + "@vitest/utils": "4.1.4", + "ast-v8-to-istanbul": "^1.0.0", + "istanbul-lib-coverage": "^3.2.2", + "istanbul-lib-report": "^3.0.1", + "istanbul-reports": "^3.2.0", + "magicast": "^0.5.2", + "obug": "^2.1.1", + "std-env": "^4.0.0-rc.1", + "tinyrainbow": "^3.1.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@vitest/browser": "4.1.4", + "vitest": "4.1.4" + }, + "peerDependenciesMeta": { + "@vitest/browser": { + "optional": true + } + } + }, + "node_modules/@vitest/expect": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "@standard-schema/spec": "^1.1.0", + "@types/chai": "^5.2.2", + "@vitest/spy": "4.1.4", + "@vitest/utils": "4.1.4", + "chai": "^6.2.2", + "tinyrainbow": "^3.1.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/mocker": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/spy": "4.1.4", + "estree-walker": "^3.0.3", + "magic-string": "^0.30.21" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "msw": "^2.4.9", + "vite": "^6.0.0 || ^7.0.0 || ^8.0.0" + }, + "peerDependenciesMeta": { + "msw": { + "optional": true + }, + "vite": { + "optional": true + } + } + }, + "node_modules/@vitest/pretty-format": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "tinyrainbow": "^3.1.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/runner": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/utils": "4.1.4", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/snapshot": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.1.4", + "@vitest/utils": "4.1.4", + "magic-string": "^0.30.21", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/spy": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/@vitest/utils": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "4.1.4", + "convert-source-map": "^2.0.0", + "tinyrainbow": "^3.1.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/accepts": { + "version": "2.0.0", + "license": "MIT", + "dependencies": { + "mime-types": "^3.0.0", + "negotiator": "^1.0.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/acorn": { + "version": "8.16.0", + "dev": true, + "license": "MIT", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-jsx": { + "version": "5.3.2", + "dev": true, + "license": "MIT", + "peerDependencies": { + "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" + } + }, + "node_modules/agent-base": { + "version": "7.1.4", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 14" + } + }, + "node_modules/ajv": { + "version": "8.18.0", + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ajv-formats": { + "version": "3.0.1", + "license": "MIT", + "peer": true, + "dependencies": { + "ajv": "^8.0.0" + }, + "peerDependencies": { + "ajv": "^8.0.0" + }, + "peerDependenciesMeta": { + "ajv": { + "optional": true + } + } + }, + "node_modules/ansi-regex": { + "version": "6.2.2", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-regex?sponsor=1" + } + }, + "node_modules/assertion-error": { + "version": "2.0.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + } + }, + "node_modules/ast-v8-to-istanbul": { + "version": "1.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/trace-mapping": "^0.3.31", + "estree-walker": "^3.0.3", + "js-tokens": "^10.0.0" + } + }, + "node_modules/available-typed-arrays": { + "version": "1.0.7", + "dev": true, + "license": "MIT", + "dependencies": { + "possible-typed-array-names": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/aws4fetch": { + "version": "1.0.20", + "dev": true, + "license": "MIT" + }, + "node_modules/balanced-match": { + "version": "4.0.4", + "dev": true, + "license": "MIT", + "engines": { + "node": "18 || 20 || >=22" + } + }, + "node_modules/base64-js": { + "version": "1.5.1", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/bignumber.js": { + "version": "9.3.1", + "dev": true, + "license": "MIT", + "engines": { + "node": "*" + } + }, + "node_modules/binaryen": { + "version": "123.0.0", + "dev": true, + "license": "Apache-2.0", + "bin": { + "wasm-as": "bin/wasm-as", + "wasm-ctor-eval": "bin/wasm-ctor-eval", + "wasm-dis": "bin/wasm-dis", + "wasm-merge": "bin/wasm-merge", + "wasm-metadce": "bin/wasm-metadce", + "wasm-opt": "bin/wasm-opt", + "wasm-reduce": "bin/wasm-reduce", + "wasm-shell": "bin/wasm-shell", + "wasm2js": "bin/wasm2js" + } + }, + "node_modules/bl": { + "version": "1.2.3", + "dev": true, + "license": "MIT", + "dependencies": { + "readable-stream": "^2.3.5", + "safe-buffer": "^5.1.1" + } + }, + "node_modules/body-parser": { + "version": "2.2.2", + "license": "MIT", + "dependencies": { + "bytes": "^3.1.2", + "content-type": "^1.0.5", + "debug": "^4.4.3", + "http-errors": "^2.0.0", + "iconv-lite": "^0.7.0", + "on-finished": "^2.4.1", + "qs": "^6.14.1", + "raw-body": "^3.0.1", + "type-is": "^2.0.1" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/bowser": { + "version": "2.14.1", + "license": "MIT" + }, + "node_modules/brace-expansion": { + "version": "5.0.5", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^4.0.2" + }, + "engines": { + "node": "18 || 20 || >=22" + } + }, + "node_modules/buffer": { + "version": "5.7.1", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT", + "dependencies": { + "base64-js": "^1.3.1", + "ieee754": "^1.1.13" + } + }, + "node_modules/buffer-alloc": { + "version": "1.2.0", + "dev": true, + "license": "MIT", + "dependencies": { + "buffer-alloc-unsafe": "^1.1.0", + "buffer-fill": "^1.0.0" + } + }, + "node_modules/buffer-alloc-unsafe": { + "version": "1.1.0", + "dev": true, + "license": "MIT" + }, + "node_modules/buffer-crc32": { + "version": "0.2.13", + "dev": true, + "license": "MIT", + "engines": { + "node": "*" + } + }, + "node_modules/buffer-equal-constant-time": { + "version": "1.0.1", + "dev": true, + "license": "BSD-3-Clause" + }, + "node_modules/buffer-fill": { + "version": "1.0.0", + "dev": true, + "license": "MIT" + }, + "node_modules/buffer-from": { + "version": "1.1.2", + "dev": true, + "license": "MIT" + }, + "node_modules/bytes": { + "version": "3.1.2", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/cac": { + "version": "6.7.14", + "resolved": "https://registry.npmjs.org/cac/-/cac-6.7.14.tgz", + "integrity": "sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/call-bind": { + "version": "1.0.9", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "get-intrinsic": "^1.3.0", + "set-function-length": "^1.2.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/call-bound": { + "version": "1.0.4", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/chai": { + "version": "6.2.2", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, + "node_modules/check-error": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/check-error/-/check-error-2.1.3.tgz", + "integrity": "sha512-PAJdDJusoxnwm1VwW07VWwUN1sl7smmC3OKggvndJFadxxDRyFJBX/ggnu/KE4kQAB7a3Dp8f/YXC1FlUprWmA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 16" + } + }, + "node_modules/cli-cursor": { + "version": "5.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "restore-cursor": "^5.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/cli-spinners": { + "version": "2.9.2", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/commander": { + "version": "14.0.3", + "resolved": "https://registry.npmjs.org/commander/-/commander-14.0.3.tgz", + "integrity": "sha512-H+y0Jo/T1RZ9qPP4Eh1pkcQcLRglraJaSLoyOtHxu6AapkjWVCy2Sit1QQ4x3Dng8qDlSsZEet7g5Pq06MvTgw==", + "license": "MIT", + "engines": { + "node": ">=20" + } + }, + "node_modules/content-disposition": { + "version": "1.1.0", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/content-type": { + "version": "1.0.5", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/convert-source-map": { + "version": "2.0.0", + "dev": true, + "license": "MIT" + }, + "node_modules/cookie": { + "version": "0.7.2", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/cookie-signature": { + "version": "1.2.2", + "license": "MIT", + "engines": { + "node": ">=6.6.0" + } + }, + "node_modules/core-util-is": { + "version": "1.0.3", + "dev": true, + "license": "MIT" + }, + "node_modules/cors": { + "version": "2.8.6", + "license": "MIT", + "peer": true, + "dependencies": { + "object-assign": "^4", + "vary": "^1" + }, + "engines": { + "node": ">= 0.10" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/data-uri-to-buffer": { + "version": "4.0.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 12" + } + }, + "node_modules/debug": { + "version": "4.4.3", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/decompress": { + "version": "4.2.1", + "dev": true, + "license": "MIT", + "dependencies": { + "decompress-tar": "^4.0.0", + "decompress-tarbz2": "^4.0.0", + "decompress-targz": "^4.0.0", + "decompress-unzip": "^4.0.1", + "graceful-fs": "^4.1.10", + "make-dir": "^1.0.0", + "pify": "^2.3.0", + "strip-dirs": "^2.0.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/decompress-tar": { + "version": "4.1.1", + "dev": true, + "license": "MIT", + "dependencies": { + "file-type": "^5.2.0", + "is-stream": "^1.1.0", + "tar-stream": "^1.5.2" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/decompress-tarbz2": { + "version": "4.1.1", + "dev": true, + "license": "MIT", + "dependencies": { + "decompress-tar": "^4.1.0", + "file-type": "^6.1.0", + "is-stream": "^1.1.0", + "seek-bzip": "^1.0.5", + "unbzip2-stream": "^1.0.9" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/decompress-tarbz2/node_modules/file-type": { + "version": "6.2.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/decompress-targz": { + "version": "4.1.1", + "dev": true, + "license": "MIT", + "dependencies": { + "decompress-tar": "^4.1.1", + "file-type": "^5.2.0", + "is-stream": "^1.1.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/decompress-unzip": { + "version": "4.0.1", + "dev": true, + "license": "MIT", + "dependencies": { + "file-type": "^3.8.0", + "get-stream": "^2.2.0", + "pify": "^2.3.0", + "yauzl": "^2.4.2" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/decompress-unzip/node_modules/file-type": { + "version": "3.9.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/deep-eql": { + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/deep-eql/-/deep-eql-5.0.2.tgz", + "integrity": "sha512-h5k/5U50IJJFpzfL6nO9jaaumfjO/f2NjK/oYB2Djzm4p9L+3T9qWpZqZ2hAbLPuuYq9wrU08WQyBTL5GbPk5Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/deep-is": { + "version": "0.1.4", + "dev": true, + "license": "MIT" + }, + "node_modules/define-data-property": { + "version": "1.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/depd": { + "version": "2.0.0", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/ecdsa-sig-formatter": { + "version": "1.0.11", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "safe-buffer": "^5.0.1" + } + }, + "node_modules/ee-first": { + "version": "1.1.1", + "license": "MIT" + }, + "node_modules/emoji-regex": { + "version": "10.6.0", + "dev": true, + "license": "MIT" + }, + "node_modules/encodeurl": { + "version": "2.0.0", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/end-of-stream": { + "version": "1.4.5", + "dev": true, + "license": "MIT", + "dependencies": { + "once": "^1.4.0" + } + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-module-lexer": { + "version": "1.7.0", + "dev": true, + "license": "MIT" + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/esbuild": { + "version": "0.27.7", + "hasInstallScript": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.27.7", + "@esbuild/android-arm": "0.27.7", + "@esbuild/android-arm64": "0.27.7", + "@esbuild/android-x64": "0.27.7", + "@esbuild/darwin-arm64": "0.27.7", + "@esbuild/darwin-x64": "0.27.7", + "@esbuild/freebsd-arm64": "0.27.7", + "@esbuild/freebsd-x64": "0.27.7", + "@esbuild/linux-arm": "0.27.7", + "@esbuild/linux-arm64": "0.27.7", + "@esbuild/linux-ia32": "0.27.7", + "@esbuild/linux-loong64": "0.27.7", + "@esbuild/linux-mips64el": "0.27.7", + "@esbuild/linux-ppc64": "0.27.7", + "@esbuild/linux-riscv64": "0.27.7", + "@esbuild/linux-s390x": "0.27.7", + "@esbuild/linux-x64": "0.27.7", + "@esbuild/netbsd-arm64": "0.27.7", + "@esbuild/netbsd-x64": "0.27.7", + "@esbuild/openbsd-arm64": "0.27.7", + "@esbuild/openbsd-x64": "0.27.7", + "@esbuild/openharmony-arm64": "0.27.7", + "@esbuild/sunos-x64": "0.27.7", + "@esbuild/win32-arm64": "0.27.7", + "@esbuild/win32-ia32": "0.27.7", + "@esbuild/win32-x64": "0.27.7" + } + }, + "node_modules/esbuild/node_modules/@esbuild/aix-ppc64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.27.7.tgz", + "integrity": "sha512-EKX3Qwmhz1eMdEJokhALr0YiD0lhQNwDqkPYyPhiSwKrh7/4KRjQc04sZ8db+5DVVnZ1LmbNDI1uAMPEUBnQPg==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/android-arm": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.27.7.tgz", + "integrity": "sha512-jbPXvB4Yj2yBV7HUfE2KHe4GJX51QplCN1pGbYjvsyCZbQmies29EoJbkEc+vYuU5o45AfQn37vZlyXy4YJ8RQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/android-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.27.7.tgz", + "integrity": "sha512-62dPZHpIXzvChfvfLJow3q5dDtiNMkwiRzPylSCfriLvZeq0a1bWChrGx/BbUbPwOrsWKMn8idSllklzBy+dgQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/android-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.27.7.tgz", + "integrity": "sha512-x5VpMODneVDb70PYV2VQOmIUUiBtY3D3mPBG8NxVk5CogneYhkR7MmM3yR/uMdITLrC1ml/NV1rj4bMJuy9MCg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/darwin-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.27.7.tgz", + "integrity": "sha512-rYnXrKcXuT7Z+WL5K980jVFdvVKhCHhUwid+dDYQpH+qu+TefcomiMAJpIiC2EM3Rjtq0sO3StMV/+3w3MyyqQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/freebsd-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.27.7.tgz", + "integrity": "sha512-B48PqeCsEgOtzME2GbNM2roU29AMTuOIN91dsMO30t+Ydis3z/3Ngoj5hhnsOSSwNzS+6JppqWsuhTp6E82l2w==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/freebsd-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.27.7.tgz", + "integrity": "sha512-jOBDK5XEjA4m5IJK3bpAQF9/Lelu/Z9ZcdhTRLf4cajlB+8VEhFFRjWgfy3M1O4rO2GQ/b2dLwCUGpiF/eATNQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/linux-arm": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.27.7.tgz", + "integrity": "sha512-RkT/YXYBTSULo3+af8Ib0ykH8u2MBh57o7q/DAs3lTJlyVQkgQvlrPTnjIzzRPQyavxtPtfg0EopvDyIt0j1rA==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/linux-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.27.7.tgz", + "integrity": "sha512-RZPHBoxXuNnPQO9rvjh5jdkRmVizktkT7TCDkDmQ0W2SwHInKCAV95GRuvdSvA7w4VMwfCjUiPwDi0ZO6Nfe9A==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/linux-ia32": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.27.7.tgz", + "integrity": "sha512-GA48aKNkyQDbd3KtkplYWT102C5sn/EZTY4XROkxONgruHPU72l+gW+FfF8tf2cFjeHaRbWpOYa/uRBz/Xq1Pg==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/linux-loong64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.27.7.tgz", + "integrity": "sha512-a4POruNM2oWsD4WKvBSEKGIiWQF8fZOAsycHOt6JBpZ+JN2n2JH9WAv56SOyu9X5IqAjqSIPTaJkqN8F7XOQ5Q==", + "cpu": [ + "loong64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/linux-mips64el": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.27.7.tgz", + "integrity": "sha512-KabT5I6StirGfIz0FMgl1I+R1H73Gp0ofL9A3nG3i/cYFJzKHhouBV5VWK1CSgKvVaG4q1RNpCTR2LuTVB3fIw==", + "cpu": [ + "mips64el" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/linux-ppc64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.27.7.tgz", + "integrity": "sha512-gRsL4x6wsGHGRqhtI+ifpN/vpOFTQtnbsupUF5R5YTAg+y/lKelYR1hXbnBdzDjGbMYjVJLJTd2OFmMewAgwlQ==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/linux-riscv64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.27.7.tgz", + "integrity": "sha512-hL25LbxO1QOngGzu2U5xeXtxXcW+/GvMN3ejANqXkxZ/opySAZMrc+9LY/WyjAan41unrR3YrmtTsUpwT66InQ==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/linux-s390x": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.27.7.tgz", + "integrity": "sha512-2k8go8Ycu1Kb46vEelhu1vqEP+UeRVj2zY1pSuPdgvbd5ykAw82Lrro28vXUrRmzEsUV0NzCf54yARIK8r0fdw==", + "cpu": [ + "s390x" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/linux-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.27.7.tgz", + "integrity": "sha512-hzznmADPt+OmsYzw1EE33ccA+HPdIqiCRq7cQeL1Jlq2gb1+OyWBkMCrYGBJ+sxVzve2ZJEVeePbLM2iEIZSxA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/netbsd-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.27.7.tgz", + "integrity": "sha512-b6pqtrQdigZBwZxAn1UpazEisvwaIDvdbMbmrly7cDTMFnw/+3lVxxCTGOrkPVnsYIosJJXAsILG9XcQS+Yu6w==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/netbsd-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.27.7.tgz", + "integrity": "sha512-OfatkLojr6U+WN5EDYuoQhtM+1xco+/6FSzJJnuWiUw5eVcicbyK3dq5EeV/QHT1uy6GoDhGbFpprUiHUYggrw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/openbsd-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.27.7.tgz", + "integrity": "sha512-AFuojMQTxAz75Fo8idVcqoQWEHIXFRbOc1TrVcFSgCZtQfSdc1RXgB3tjOn/krRHENUB4j00bfGjyl2mJrU37A==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/openbsd-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.27.7.tgz", + "integrity": "sha512-+A1NJmfM8WNDv5CLVQYJ5PshuRm/4cI6WMZRg1by1GwPIQPCTs1GLEUHwiiQGT5zDdyLiRM/l1G0Pv54gvtKIg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/openharmony-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/openharmony-arm64/-/openharmony-arm64-0.27.7.tgz", + "integrity": "sha512-+KrvYb/C8zA9CU/g0sR6w2RBw7IGc5J2BPnc3dYc5VJxHCSF1yNMxTV5LQ7GuKteQXZtspjFbiuW5/dOj7H4Yw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openharmony" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/sunos-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.27.7.tgz", + "integrity": "sha512-ikktIhFBzQNt/QDyOL580ti9+5mL/YZeUPKU2ivGtGjdTYoqz6jObj6nOMfhASpS4GU4Q/Clh1QtxWAvcYKamA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/win32-arm64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.27.7.tgz", + "integrity": "sha512-7yRhbHvPqSpRUV7Q20VuDwbjW5kIMwTHpptuUzV+AA46kiPze5Z7qgt6CLCK3pWFrHeNfDd1VKgyP4O+ng17CA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/win32-ia32": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.27.7.tgz", + "integrity": "sha512-SmwKXe6VHIyZYbBLJrhOoCJRB/Z1tckzmgTLfFYOfpMAx63BJEaL9ExI8x7v0oAO3Zh6D/Oi1gVxEYr5oUCFhw==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/esbuild/node_modules/@esbuild/win32-x64": { + "version": "0.27.7", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.27.7.tgz", + "integrity": "sha512-56hiAJPhwQ1R4i+21FVF7V8kSD5zZTdHcVuRFMW0hn753vVfQN8xlx4uOPT4xoGH0Z/oVATuR82AiqSTDIpaHg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/escape-html": { + "version": "1.0.3", + "license": "MIT" + }, + "node_modules/escape-string-regexp": { + "version": "4.0.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/eslint": { + "version": "10.2.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.8.0", + "@eslint-community/regexpp": "^4.12.2", + "@eslint/config-array": "^0.23.5", + "@eslint/config-helpers": "^0.5.5", + "@eslint/core": "^1.2.1", + "@eslint/plugin-kit": "^0.7.1", + "@humanfs/node": "^0.16.6", + "@humanwhocodes/module-importer": "^1.0.1", + "@humanwhocodes/retry": "^0.4.2", + "@types/estree": "^1.0.6", + "ajv": "^6.14.0", + "cross-spawn": "^7.0.6", + "debug": "^4.3.2", + "escape-string-regexp": "^4.0.0", + "eslint-scope": "^9.1.2", + "eslint-visitor-keys": "^5.0.1", + "espree": "^11.2.0", + "esquery": "^1.7.0", + "esutils": "^2.0.2", + "fast-deep-equal": "^3.1.3", + "file-entry-cache": "^8.0.0", + "find-up": "^5.0.0", + "glob-parent": "^6.0.2", + "ignore": "^5.2.0", + "imurmurhash": "^0.1.4", + "is-glob": "^4.0.0", + "json-stable-stringify-without-jsonify": "^1.0.1", + "minimatch": "^10.2.4", + "natural-compare": "^1.4.0", + "optionator": "^0.9.3" + }, + "bin": { + "eslint": "bin/eslint.js" + }, + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + }, + "funding": { + "url": "https://eslint.org/donate" + }, + "peerDependencies": { + "jiti": "*" + }, + "peerDependenciesMeta": { + "jiti": { + "optional": true + } + } + }, + "node_modules/eslint-plugin-tsdoc": { + "version": "0.5.2", + "dev": true, + "license": "MIT", + "dependencies": { + "@microsoft/tsdoc": "0.16.0", + "@microsoft/tsdoc-config": "0.18.1", + "@typescript-eslint/utils": "~8.56.0" + } + }, + "node_modules/eslint-plugin-tsdoc/node_modules/@typescript-eslint/project-service": { + "version": "8.56.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/tsconfig-utils": "^8.56.1", + "@typescript-eslint/types": "^8.56.1", + "debug": "^4.4.3" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/eslint-plugin-tsdoc/node_modules/@typescript-eslint/scope-manager": { + "version": "8.56.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.56.1", + "@typescript-eslint/visitor-keys": "8.56.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/eslint-plugin-tsdoc/node_modules/@typescript-eslint/tsconfig-utils": { + "version": "8.56.1", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/eslint-plugin-tsdoc/node_modules/@typescript-eslint/types": { + "version": "8.56.1", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/eslint-plugin-tsdoc/node_modules/@typescript-eslint/typescript-estree": { + "version": "8.56.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/project-service": "8.56.1", + "@typescript-eslint/tsconfig-utils": "8.56.1", + "@typescript-eslint/types": "8.56.1", + "@typescript-eslint/visitor-keys": "8.56.1", + "debug": "^4.4.3", + "minimatch": "^10.2.2", + "semver": "^7.7.3", + "tinyglobby": "^0.2.15", + "ts-api-utils": "^2.4.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/eslint-plugin-tsdoc/node_modules/@typescript-eslint/utils": { + "version": "8.56.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.9.1", + "@typescript-eslint/scope-manager": "8.56.1", + "@typescript-eslint/types": "8.56.1", + "@typescript-eslint/typescript-estree": "8.56.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/eslint-plugin-tsdoc/node_modules/@typescript-eslint/visitor-keys": { + "version": "8.56.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.56.1", + "eslint-visitor-keys": "^5.0.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/eslint-plugin-tsdoc/node_modules/eslint-visitor-keys": { + "version": "5.0.1", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint-scope": { + "version": "9.1.2", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "@types/esrecurse": "^4.3.1", + "@types/estree": "^1.0.8", + "esrecurse": "^4.3.0", + "estraverse": "^5.2.0" + }, + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint-visitor-keys": { + "version": "3.4.3", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^12.22.0 || ^14.17.0 || >=16.0.0" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint/node_modules/ajv": { + "version": "6.14.0", + "dev": true, + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/eslint/node_modules/eslint-visitor-keys": { + "version": "5.0.1", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/eslint/node_modules/ignore": { + "version": "5.3.2", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/eslint/node_modules/json-schema-traverse": { + "version": "0.4.1", + "dev": true, + "license": "MIT" + }, + "node_modules/espree": { + "version": "11.2.0", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "acorn": "^8.16.0", + "acorn-jsx": "^5.3.2", + "eslint-visitor-keys": "^5.0.1" + }, + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/espree/node_modules/eslint-visitor-keys": { + "version": "5.0.1", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, + "node_modules/esquery": { + "version": "1.7.0", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "estraverse": "^5.1.0" + }, + "engines": { + "node": ">=0.10" + } + }, + "node_modules/esrecurse": { + "version": "4.3.0", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "estraverse": "^5.2.0" + }, + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estraverse": { + "version": "5.3.0", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=4.0" + } + }, + "node_modules/estree-walker": { + "version": "3.0.3", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0" + } + }, + "node_modules/esutils": { + "version": "2.0.3", + "dev": true, + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/etag": { + "version": "1.8.1", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/eventsource": { + "version": "3.0.7", + "license": "MIT", + "peer": true, + "dependencies": { + "eventsource-parser": "^3.0.1" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/eventsource-parser": { + "version": "3.0.8", + "license": "MIT", + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/expect-type": { + "version": "1.3.0", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=12.0.0" + } + }, + "node_modules/express": { + "version": "5.2.1", + "license": "MIT", + "dependencies": { + "accepts": "^2.0.0", + "body-parser": "^2.2.1", + "content-disposition": "^1.0.0", + "content-type": "^1.0.5", + "cookie": "^0.7.1", + "cookie-signature": "^1.2.1", + "debug": "^4.4.0", + "depd": "^2.0.0", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "etag": "^1.8.1", + "finalhandler": "^2.1.0", + "fresh": "^2.0.0", + "http-errors": "^2.0.0", + "merge-descriptors": "^2.0.0", + "mime-types": "^3.0.0", + "on-finished": "^2.4.1", + "once": "^1.4.0", + "parseurl": "^1.3.3", + "proxy-addr": "^2.0.7", + "qs": "^6.14.0", + "range-parser": "^1.2.1", + "router": "^2.2.0", + "send": "^1.1.0", + "serve-static": "^2.2.0", + "statuses": "^2.0.1", + "type-is": "^2.0.1", + "vary": "^1.1.2" + }, + "engines": { + "node": ">= 18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/express-rate-limit": { + "version": "8.5.1", + "resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.5.1.tgz", + "integrity": "sha512-5O6KYmyJEpuPJV5hNTXKbAHWRqrzyu+OI3vUnSd2kXFubIVpG7ezpgxQy76Zo5GQZtrQBg86hF+CM/NX+cioiQ==", + "license": "MIT", + "peer": true, + "dependencies": { + "ip-address": "^10.2.0" + }, + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://github.com/sponsors/express-rate-limit" + }, + "peerDependencies": { + "express": ">= 4.11" + } + }, + "node_modules/extend": { + "version": "3.0.2", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "license": "MIT" + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-levenshtein": { + "version": "2.0.6", + "dev": true, + "license": "MIT" + }, + "node_modules/fast-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.2.tgz", + "integrity": "sha512-rVjf7ArG3LTk+FS6Yw81V1DLuZl1bRbNrev6Tmd/9RaroeeRRJhAt7jg/6YFxbvAQXUCavSoZhPPj6oOx+5KjQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/fastify" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fastify" + } + ], + "license": "BSD-3-Clause" + }, + "node_modules/fast-xml-builder": { + "version": "1.1.7", + "resolved": "https://registry.npmjs.org/fast-xml-builder/-/fast-xml-builder-1.1.7.tgz", + "integrity": "sha512-Yh7/7rQuMXICNr0oMYDR2yHP6oUvmQsTToFeOWj/kIDhAwQ+c4Ol/lbcwOmEM5OHYQmh6S6EQSQ1sljCKP36bQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT", + "dependencies": { + "path-expression-matcher": "^1.1.3" + } + }, + "node_modules/fast-xml-parser": { + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-5.7.2.tgz", + "integrity": "sha512-P7oW7tLbYnhOLQk/Gv7cZgzgMPP/XN03K02/Jy6Y/NHzyIAIpxuZIM/YqAkfiXFPxA2CTm7NtCijK9EDu09u2w==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT", + "dependencies": { + "@nodable/entities": "^2.1.0", + "fast-xml-builder": "^1.1.5", + "path-expression-matcher": "^1.5.0", + "strnum": "^2.2.3" + }, + "bin": { + "fxparser": "src/cli/cli.js" + } + }, + "node_modules/fd-slicer": { + "version": "1.1.0", + "dev": true, + "license": "MIT", + "dependencies": { + "pend": "~1.2.0" + } + }, + "node_modules/fdir": { + "version": "6.5.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12.0.0" + }, + "peerDependencies": { + "picomatch": "^3 || ^4" + }, + "peerDependenciesMeta": { + "picomatch": { + "optional": true + } + } + }, + "node_modules/fetch-blob": { + "version": "3.2.0", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/jimmywarting" + }, + { + "type": "paypal", + "url": "https://paypal.me/jimmywarting" + } + ], + "license": "MIT", + "dependencies": { + "node-domexception": "^1.0.0", + "web-streams-polyfill": "^3.0.3" + }, + "engines": { + "node": "^12.20 || >= 14.13" + } + }, + "node_modules/file-entry-cache": { + "version": "8.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "flat-cache": "^4.0.0" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/file-type": { + "version": "5.2.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/finalhandler": { + "version": "2.1.1", + "license": "MIT", + "dependencies": { + "debug": "^4.4.0", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "on-finished": "^2.4.1", + "parseurl": "^1.3.3", + "statuses": "^2.0.1" + }, + "engines": { + "node": ">= 18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/find-up": { + "version": "5.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "locate-path": "^6.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/flat-cache": { + "version": "4.0.1", + "dev": true, + "license": "MIT", + "dependencies": { + "flatted": "^3.2.9", + "keyv": "^4.5.4" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/flatted": { + "version": "3.4.2", + "dev": true, + "license": "ISC" + }, + "node_modules/for-each": { + "version": "0.3.5", + "dev": true, + "license": "MIT", + "dependencies": { + "is-callable": "^1.2.7" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/formdata-polyfill": { + "version": "4.0.10", + "dev": true, + "license": "MIT", + "dependencies": { + "fetch-blob": "^3.1.2" + }, + "engines": { + "node": ">=12.20.0" + } + }, + "node_modules/forwarded": { + "version": "0.2.0", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/fresh": { + "version": "2.0.0", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/fs-constants": { + "version": "1.0.0", + "dev": true, + "license": "MIT" + }, + "node_modules/fsevents": { + "version": "2.3.2", + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/function-bind": { + "version": "1.1.2", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/gaxios": { + "version": "7.1.4", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "extend": "^3.0.2", + "https-proxy-agent": "^7.0.1", + "node-fetch": "^3.3.2" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/gcp-metadata": { + "version": "8.1.2", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "gaxios": "^7.0.0", + "google-logging-utils": "^1.0.0", + "json-bigint": "^1.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/get-east-asian-width": { + "version": "1.5.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/get-stream": { + "version": "2.3.1", + "dev": true, + "license": "MIT", + "dependencies": { + "object-assign": "^4.0.1", + "pinkie-promise": "^2.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/get-tsconfig": { + "version": "4.14.0", + "license": "MIT", + "dependencies": { + "resolve-pkg-maps": "^1.0.0" + }, + "funding": { + "url": "https://github.com/privatenumber/get-tsconfig?sponsor=1" + } + }, + "node_modules/glob-parent": { + "version": "6.0.2", + "dev": true, + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/google-auth-library": { + "version": "10.6.2", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "base64-js": "^1.3.0", + "ecdsa-sig-formatter": "^1.0.11", + "gaxios": "^7.1.4", + "gcp-metadata": "8.1.2", + "google-logging-utils": "1.1.3", + "jws": "^4.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/google-logging-utils": { + "version": "1.1.3", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=14" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "dev": true, + "license": "ISC" + }, + "node_modules/has-flag": { + "version": "4.0.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/has-property-descriptors": { + "version": "1.0.2", + "dev": true, + "license": "MIT", + "dependencies": { + "es-define-property": "^1.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "dev": true, + "license": "MIT", + "dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hasown": { + "version": "2.0.3", + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/hono": { + "version": "4.12.18", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.18.tgz", + "integrity": "sha512-RWzP96k/yv0PQfyXnWjs6zot20TqfpfsNXhOnev8d1InAxubW93L11/oNUc3tQqn2G0bSdAOBpX+2uDFHV7kdQ==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">=16.9.0" + } + }, + "node_modules/html-escaper": { + "version": "2.0.2", + "dev": true, + "license": "MIT" + }, + "node_modules/http-errors": { + "version": "2.0.1", + "license": "MIT", + "dependencies": { + "depd": "~2.0.0", + "inherits": "~2.0.4", + "setprototypeof": "~1.2.0", + "statuses": "~2.0.2", + "toidentifier": "~1.0.1" + }, + "engines": { + "node": ">= 0.8" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/https-proxy-agent": { + "version": "7.0.6", + "dev": true, + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.2", + "debug": "4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/husky": { + "version": "9.1.7", + "dev": true, + "license": "MIT", + "bin": { + "husky": "bin.js" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/typicode" + } + }, + "node_modules/iconv-lite": { + "version": "0.7.2", + "license": "MIT", + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3.0.0" + }, + "engines": { + "node": ">=0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/ieee754": { + "version": "1.2.1", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "BSD-3-Clause" + }, + "node_modules/ignore": { + "version": "7.0.5", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/imurmurhash": { + "version": "0.1.4", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.8.19" + } + }, + "node_modules/inherits": { + "version": "2.0.4", + "license": "ISC" + }, + "node_modules/ip-address": { + "version": "10.2.0", + "resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.2.0.tgz", + "integrity": "sha512-/+S6j4E9AHvW9SWMSEY9Xfy66O5PWvVEJ08O0y5JGyEKQpojb0K0GKpz/v5HJ/G0vi3D2sjGK78119oXZeE0qA==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">= 12" + } + }, + "node_modules/ipaddr.js": { + "version": "1.9.1", + "license": "MIT", + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/is-callable": { + "version": "1.2.7", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-core-module": { + "version": "2.16.1", + "dev": true, + "license": "MIT", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "dev": true, + "license": "MIT", + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-interactive": { + "version": "2.0.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-natural-number": { + "version": "4.0.1", + "dev": true, + "license": "MIT" + }, + "node_modules/is-promise": { + "version": "4.0.0", + "license": "MIT" + }, + "node_modules/is-stream": { + "version": "1.1.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-typed-array": { + "version": "1.1.15", + "dev": true, + "license": "MIT", + "dependencies": { + "which-typed-array": "^1.1.16" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-unicode-supported": { + "version": "2.1.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/isarray": { + "version": "1.0.0", + "dev": true, + "license": "MIT" + }, + "node_modules/isexe": { + "version": "2.0.0", + "license": "ISC" + }, + "node_modules/istanbul-lib-coverage": { + "version": "3.2.2", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=8" + } + }, + "node_modules/istanbul-lib-report": { + "version": "3.0.1", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "istanbul-lib-coverage": "^3.0.0", + "make-dir": "^4.0.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/istanbul-lib-report/node_modules/make-dir": { + "version": "4.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "semver": "^7.5.3" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/istanbul-reports": { + "version": "3.2.0", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "html-escaper": "^2.0.0", + "istanbul-lib-report": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/jju": { + "version": "1.4.0", + "dev": true, + "license": "MIT" + }, + "node_modules/jose": { + "version": "6.2.2", + "license": "MIT", + "peer": true, + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, + "node_modules/js-tokens": { + "version": "10.0.0", + "dev": true, + "license": "MIT" + }, + "node_modules/json-bigint": { + "version": "1.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "bignumber.js": "^9.0.0" + } + }, + "node_modules/json-buffer": { + "version": "3.0.1", + "dev": true, + "license": "MIT" + }, + "node_modules/json-schema": { + "version": "0.4.0", + "dev": true, + "license": "(AFL-2.1 OR BSD-3-Clause)" + }, + "node_modules/json-schema-to-ts": { + "version": "3.1.1", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.18.3", + "ts-algebra": "^2.0.0" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/json-schema-traverse": { + "version": "1.0.0", + "license": "MIT" + }, + "node_modules/json-schema-typed": { + "version": "8.0.2", + "license": "BSD-2-Clause", + "peer": true + }, + "node_modules/json-stable-stringify-without-jsonify": { + "version": "1.0.1", + "dev": true, + "license": "MIT" + }, + "node_modules/jwa": { + "version": "2.0.1", + "dev": true, + "license": "MIT", + "dependencies": { + "buffer-equal-constant-time": "^1.0.1", + "ecdsa-sig-formatter": "1.0.11", + "safe-buffer": "^5.0.1" + } + }, + "node_modules/jws": { + "version": "4.0.1", + "dev": true, + "license": "MIT", + "dependencies": { + "jwa": "^2.0.1", + "safe-buffer": "^5.0.1" + } + }, + "node_modules/keyv": { + "version": "4.5.4", + "dev": true, + "license": "MIT", + "dependencies": { + "json-buffer": "3.0.1" + } + }, + "node_modules/levn": { + "version": "0.4.1", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1", + "type-check": "~0.4.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/locate-path": { + "version": "6.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "p-locate": "^5.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/log-symbols": { + "version": "6.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "chalk": "^5.3.0", + "is-unicode-supported": "^1.3.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/log-symbols/node_modules/chalk": { + "version": "5.6.2", + "dev": true, + "license": "MIT", + "engines": { + "node": "^12.17.0 || ^14.13 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/log-symbols/node_modules/is-unicode-supported": { + "version": "1.3.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/long": { + "version": "5.3.2", + "dev": true, + "license": "Apache-2.0" + }, + "node_modules/loupe": { + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/loupe/-/loupe-3.2.1.tgz", + "integrity": "sha512-CdzqowRJCeLU72bHvWqwRBBlLcMEtIvGrlvef74kMnV2AolS9Y8xUv1I0U/MNAWMhBlKIoyuEgoJ0t/bbwHbLQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/magic-string": { + "version": "0.30.21", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.5" + } + }, + "node_modules/magicast": { + "version": "0.5.2", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.29.0", + "@babel/types": "^7.29.0", + "source-map-js": "^1.2.1" + } + }, + "node_modules/make-dir": { + "version": "1.3.0", + "dev": true, + "license": "MIT", + "dependencies": { + "pify": "^3.0.0" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/make-dir/node_modules/pify": { + "version": "3.0.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=4" + } + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/media-typer": { + "version": "1.1.0", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/merge-descriptors": { + "version": "2.0.0", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/mime-db": { + "version": "1.54.0", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "3.0.2", + "license": "MIT", + "dependencies": { + "mime-db": "^1.54.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/mimic-function": { + "version": "5.0.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/minimatch": { + "version": "10.2.5", + "dev": true, + "license": "BlueOak-1.0.0", + "dependencies": { + "brace-expansion": "^5.0.5" + }, + "engines": { + "node": "18 || 20 || >=22" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/mkdirp": { + "version": "3.0.1", + "dev": true, + "license": "MIT", + "bin": { + "mkdirp": "dist/cjs/src/bin.js" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/mrmime": { + "version": "2.0.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + } + }, + "node_modules/ms": { + "version": "2.1.3", + "license": "MIT" + }, + "node_modules/nanoid": { + "version": "3.3.11", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "bin": { + "nanoid": "bin/nanoid.cjs" + }, + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" + } + }, + "node_modules/natural-compare": { + "version": "1.4.0", + "dev": true, + "license": "MIT" + }, + "node_modules/negotiator": { + "version": "1.0.0", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/node-domexception": { + "version": "1.0.0", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/jimmywarting" + }, + { + "type": "github", + "url": "https://paypal.me/jimmywarting" + } + ], + "license": "MIT", + "engines": { + "node": ">=10.5.0" + } + }, + "node_modules/node-fetch": { + "version": "3.3.2", + "dev": true, + "license": "MIT", + "dependencies": { + "data-uri-to-buffer": "^4.0.0", + "fetch-blob": "^3.1.4", + "formdata-polyfill": "^4.0.10" + }, + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/node-fetch" + } + }, + "node_modules/object-assign": { + "version": "4.1.1", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-inspect": { + "version": "1.13.4", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/obug": { + "version": "2.1.1", + "dev": true, + "funding": [ + "https://github.com/sponsors/sxzz", + "https://opencollective.com/debug" + ], + "license": "MIT" + }, + "node_modules/on-finished": { + "version": "2.4.1", + "license": "MIT", + "dependencies": { + "ee-first": "1.1.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/once": { + "version": "1.4.0", + "license": "ISC", + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/onetime": { + "version": "7.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "mimic-function": "^5.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/openai": { + "version": "6.34.0", + "dev": true, + "license": "Apache-2.0", + "bin": { + "openai": "bin/cli" + }, + "peerDependencies": { + "ws": "^8.18.0", + "zod": "^3.25 || ^4.0" + }, + "peerDependenciesMeta": { + "ws": { + "optional": true + }, + "zod": { + "optional": true + } + } + }, + "node_modules/optionator": { + "version": "0.9.4", + "dev": true, + "license": "MIT", + "dependencies": { + "deep-is": "^0.1.3", + "fast-levenshtein": "^2.0.6", + "levn": "^0.4.1", + "prelude-ls": "^1.2.1", + "type-check": "^0.4.0", + "word-wrap": "^1.2.5" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/ora": { + "version": "8.2.0", + "dev": true, + "license": "MIT", + "dependencies": { + "chalk": "^5.3.0", + "cli-cursor": "^5.0.0", + "cli-spinners": "^2.9.2", + "is-interactive": "^2.0.0", + "is-unicode-supported": "^2.0.0", + "log-symbols": "^6.0.0", + "stdin-discarder": "^0.2.2", + "string-width": "^7.2.0", + "strip-ansi": "^7.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/ora/node_modules/chalk": { + "version": "5.6.2", + "dev": true, + "license": "MIT", + "engines": { + "node": "^12.17.0 || ^14.13 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/chalk/chalk?sponsor=1" + } + }, + "node_modules/oxc-parser": { + "version": "0.76.0", + "dev": true, + "license": "MIT", + "dependencies": { + "@oxc-project/types": "^0.76.0" + }, + "engines": { + "node": ">=20.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/Boshen" + }, + "optionalDependencies": { + "@oxc-parser/binding-android-arm64": "0.76.0", + "@oxc-parser/binding-darwin-arm64": "0.76.0", + "@oxc-parser/binding-darwin-x64": "0.76.0", + "@oxc-parser/binding-freebsd-x64": "0.76.0", + "@oxc-parser/binding-linux-arm-gnueabihf": "0.76.0", + "@oxc-parser/binding-linux-arm-musleabihf": "0.76.0", + "@oxc-parser/binding-linux-arm64-gnu": "0.76.0", + "@oxc-parser/binding-linux-arm64-musl": "0.76.0", + "@oxc-parser/binding-linux-riscv64-gnu": "0.76.0", + "@oxc-parser/binding-linux-s390x-gnu": "0.76.0", + "@oxc-parser/binding-linux-x64-gnu": "0.76.0", + "@oxc-parser/binding-linux-x64-musl": "0.76.0", + "@oxc-parser/binding-wasm32-wasi": "0.76.0", + "@oxc-parser/binding-win32-arm64-msvc": "0.76.0", + "@oxc-parser/binding-win32-x64-msvc": "0.76.0" + } + }, + "node_modules/p-limit": { + "version": "3.1.0", + "dev": true, + "license": "MIT", + "dependencies": { + "yocto-queue": "^0.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-locate": { + "version": "5.0.0", + "dev": true, + "license": "MIT", + "dependencies": { + "p-limit": "^3.0.2" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/p-retry": { + "version": "4.6.2", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/retry": "0.12.0", + "retry": "^0.13.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/parseurl": { + "version": "1.3.3", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/path-exists": { + "version": "4.0.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-expression-matcher": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/path-expression-matcher/-/path-expression-matcher-1.5.0.tgz", + "integrity": "sha512-cbrerZV+6rvdQrrD+iGMcZFEiiSrbv9Tfdkvnusy6y0x0GKBXREFg/Y65GhIfm0tnLntThhzCnfKwp1WRjeCyQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-parse": { + "version": "1.0.7", + "dev": true, + "license": "MIT" + }, + "node_modules/path-to-regexp": { + "version": "8.4.2", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/pathe": { + "version": "2.0.3", + "dev": true, + "license": "MIT" + }, + "node_modules/pathval": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/pathval/-/pathval-2.0.1.tgz", + "integrity": "sha512-//nshmD55c46FuFw26xV/xFAaB5HF9Xdap7HJBBnrKdAd6/GxDBaNA1870O79+9ueg61cZLSVc+OaFlfmObYVQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 14.16" + } + }, + "node_modules/pend": { + "version": "1.2.0", + "dev": true, + "license": "MIT" + }, + "node_modules/picocolors": { + "version": "1.1.1", + "dev": true, + "license": "ISC" + }, + "node_modules/picomatch": { + "version": "4.0.4", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/pify": { + "version": "2.3.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/pinkie": { + "version": "2.0.4", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/pinkie-promise": { + "version": "2.0.1", + "dev": true, + "license": "MIT", + "dependencies": { + "pinkie": "^2.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/pkce-challenge": { + "version": "5.0.1", + "license": "MIT", + "peer": true, + "engines": { + "node": ">=16.20.0" + } + }, + "node_modules/playwright": { + "version": "1.60.0", + "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.60.0.tgz", + "integrity": "sha512-hheHdokM8cdqCb0lcE3s+zT4t4W+vvjpGxsZlDnikarzx8tSzMebh3UiFtgqwFwnTnjYQcsyMF8ei2mCO/tpeA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "playwright-core": "1.60.0" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "fsevents": "2.3.2" + } + }, + "node_modules/playwright-core": { + "version": "1.60.0", + "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.60.0.tgz", + "integrity": "sha512-9bW6zvX/m0lEbgTKJ6YppOKx8H3VOPBMOCFh2irXFOT4BbHgrx5hPjwJYLT40Lu+4qtD36qKc/Hn56StUW57IA==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "playwright-core": "cli.js" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/pngjs": { + "version": "7.0.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.19.0" + } + }, + "node_modules/possible-typed-array-names": { + "version": "1.1.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/postcss": { + "version": "8.5.10", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "nanoid": "^3.3.11", + "picocolors": "^1.1.1", + "source-map-js": "^1.2.1" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + "node_modules/prelude-ls": { + "version": "1.2.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/prettier": { + "version": "3.8.3", + "dev": true, + "license": "MIT", + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, + "node_modules/process-nextick-args": { + "version": "2.0.1", + "dev": true, + "license": "MIT" + }, + "node_modules/protobufjs": { + "version": "7.5.8", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.5.8.tgz", + "integrity": "sha512-dvpCIeLPbXZS/Ete7yLaO7RenOdken2NHKykBXbsaGxZT0UTltcarBciw+A78SRQs9iMAAVpsYA+l8b1hTePIA==", + "dev": true, + "hasInstallScript": true, + "license": "BSD-3-Clause", + "dependencies": { + "@protobufjs/aspromise": "^1.1.2", + "@protobufjs/base64": "^1.1.2", + "@protobufjs/codegen": "^2.0.5", + "@protobufjs/eventemitter": "^1.1.0", + "@protobufjs/fetch": "^1.1.0", + "@protobufjs/float": "^1.0.2", + "@protobufjs/inquire": "^1.1.1", + "@protobufjs/path": "^1.1.2", + "@protobufjs/pool": "^1.1.0", + "@protobufjs/utf8": "^1.1.1", + "@types/node": ">=13.7.0", + "long": "^5.0.0" + }, + "engines": { + "node": ">=12.0.0" + } + }, + "node_modules/proxy-addr": { + "version": "2.0.7", + "license": "MIT", + "dependencies": { + "forwarded": "0.2.0", + "ipaddr.js": "1.9.1" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/punycode": { + "version": "2.3.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/qs": { + "version": "6.15.1", + "license": "BSD-3-Clause", + "dependencies": { + "side-channel": "^1.1.0" + }, + "engines": { + "node": ">=0.6" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/range-parser": { + "version": "1.2.1", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/raw-body": { + "version": "3.0.2", + "license": "MIT", + "dependencies": { + "bytes": "~3.1.2", + "http-errors": "~2.0.1", + "iconv-lite": "~0.7.0", + "unpipe": "~1.0.0" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/readable-stream": { + "version": "2.3.8", + "dev": true, + "license": "MIT", + "dependencies": { + "core-util-is": "~1.0.0", + "inherits": "~2.0.3", + "isarray": "~1.0.0", + "process-nextick-args": "~2.0.0", + "safe-buffer": "~5.1.1", + "string_decoder": "~1.1.1", + "util-deprecate": "~1.0.1" + } + }, + "node_modules/readable-stream/node_modules/safe-buffer": { + "version": "5.1.2", + "dev": true, + "license": "MIT" + }, + "node_modules/require-from-string": { + "version": "2.0.2", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/resolve": { + "version": "1.22.12", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "is-core-module": "^2.16.1", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve-pkg-maps": { + "version": "1.0.0", + "license": "MIT", + "funding": { + "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1" + } + }, + "node_modules/restore-cursor": { + "version": "5.1.0", + "dev": true, + "license": "MIT", + "dependencies": { + "onetime": "^7.0.0", + "signal-exit": "^4.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/retry": { + "version": "0.13.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 4" + } + }, + "node_modules/rollup": { + "version": "4.60.2", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.2.tgz", + "integrity": "sha512-J9qZyW++QK/09NyN/zeO0dG/1GdGfyp9lV8ajHnRVLfo/uFsbji5mHnDgn/qYdUHyCkM2N+8VyspgZclfAh0eQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "1.0.8" + }, + "bin": { + "rollup": "dist/bin/rollup" + }, + "engines": { + "node": ">=18.0.0", + "npm": ">=8.0.0" + }, + "optionalDependencies": { + "@rollup/rollup-android-arm-eabi": "4.60.2", + "@rollup/rollup-android-arm64": "4.60.2", + "@rollup/rollup-darwin-arm64": "4.60.2", + "@rollup/rollup-darwin-x64": "4.60.2", + "@rollup/rollup-freebsd-arm64": "4.60.2", + "@rollup/rollup-freebsd-x64": "4.60.2", + "@rollup/rollup-linux-arm-gnueabihf": "4.60.2", + "@rollup/rollup-linux-arm-musleabihf": "4.60.2", + "@rollup/rollup-linux-arm64-gnu": "4.60.2", + "@rollup/rollup-linux-arm64-musl": "4.60.2", + "@rollup/rollup-linux-loong64-gnu": "4.60.2", + "@rollup/rollup-linux-loong64-musl": "4.60.2", + "@rollup/rollup-linux-ppc64-gnu": "4.60.2", + "@rollup/rollup-linux-ppc64-musl": "4.60.2", + "@rollup/rollup-linux-riscv64-gnu": "4.60.2", + "@rollup/rollup-linux-riscv64-musl": "4.60.2", + "@rollup/rollup-linux-s390x-gnu": "4.60.2", + "@rollup/rollup-linux-x64-gnu": "4.60.2", + "@rollup/rollup-linux-x64-musl": "4.60.2", + "@rollup/rollup-openbsd-x64": "4.60.2", + "@rollup/rollup-openharmony-arm64": "4.60.2", + "@rollup/rollup-win32-arm64-msvc": "4.60.2", + "@rollup/rollup-win32-ia32-msvc": "4.60.2", + "@rollup/rollup-win32-x64-gnu": "4.60.2", + "@rollup/rollup-win32-x64-msvc": "4.60.2", + "fsevents": "~2.3.2" + } + }, + "node_modules/router": { + "version": "2.2.0", + "license": "MIT", + "dependencies": { + "debug": "^4.4.0", + "depd": "^2.0.0", + "is-promise": "^4.0.0", + "parseurl": "^1.3.3", + "path-to-regexp": "^8.0.0" + }, + "engines": { + "node": ">= 18" + } + }, + "node_modules/safe-buffer": { + "version": "5.2.1", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/safer-buffer": { + "version": "2.1.2", + "license": "MIT" + }, + "node_modules/seek-bzip": { + "version": "1.0.6", + "dev": true, + "license": "MIT", + "dependencies": { + "commander": "^2.8.1" + }, + "bin": { + "seek-bunzip": "bin/seek-bunzip", + "seek-table": "bin/seek-bzip-table" + } + }, + "node_modules/seek-bzip/node_modules/commander": { + "version": "2.20.3", + "dev": true, + "license": "MIT" + }, + "node_modules/semver": { + "version": "7.7.4", + "dev": true, + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/send": { + "version": "1.2.1", + "license": "MIT", + "dependencies": { + "debug": "^4.4.3", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "etag": "^1.8.1", + "fresh": "^2.0.0", + "http-errors": "^2.0.1", + "mime-types": "^3.0.2", + "ms": "^2.1.3", + "on-finished": "^2.4.1", + "range-parser": "^1.2.1", + "statuses": "^2.0.2" + }, + "engines": { + "node": ">= 18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/serve-static": { + "version": "2.2.1", + "license": "MIT", + "dependencies": { + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "parseurl": "^1.3.3", + "send": "^1.2.0" + }, + "engines": { + "node": ">= 18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/set-function-length": { + "version": "1.2.2", + "dev": true, + "license": "MIT", + "dependencies": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/setprototypeof": { + "version": "1.2.0", + "license": "ISC" + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/side-channel": { + "version": "1.1.0", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-list": { + "version": "1.0.1", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.4" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-map": { + "version": "1.0.1", + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-weakmap": { + "version": "1.0.2", + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + "side-channel-map": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/siginfo": { + "version": "2.0.0", + "dev": true, + "license": "ISC" + }, + "node_modules/signal-exit": { + "version": "4.1.0", + "dev": true, + "license": "ISC", + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/sirv": { + "version": "3.0.2", + "dev": true, + "license": "MIT", + "dependencies": { + "@polka/url": "^1.0.0-next.24", + "mrmime": "^2.0.0", + "totalist": "^3.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/source-map": { + "version": "0.6.1", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-js": { + "version": "1.2.1", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-support": { + "version": "0.5.21", + "dev": true, + "license": "MIT", + "dependencies": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + } + }, + "node_modules/stackback": { + "version": "0.0.2", + "dev": true, + "license": "MIT" + }, + "node_modules/statuses": { + "version": "2.0.2", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/std-env": { + "version": "4.1.0", + "dev": true, + "license": "MIT" + }, + "node_modules/stdin-discarder": { + "version": "0.2.2", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/string_decoder": { + "version": "1.1.1", + "dev": true, + "license": "MIT", + "dependencies": { + "safe-buffer": "~5.1.0" + } + }, + "node_modules/string_decoder/node_modules/safe-buffer": { + "version": "5.1.2", + "dev": true, + "license": "MIT" + }, + "node_modules/string-width": { + "version": "7.2.0", + "dev": true, + "license": "MIT", + "dependencies": { + "emoji-regex": "^10.3.0", + "get-east-asian-width": "^1.0.0", + "strip-ansi": "^7.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/strip-ansi": { + "version": "7.2.0", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^6.2.2" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/strip-ansi?sponsor=1" + } + }, + "node_modules/strip-dirs": { + "version": "2.1.0", + "dev": true, + "license": "MIT", + "dependencies": { + "is-natural-number": "^4.0.1" + } + }, + "node_modules/strip-literal": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/strip-literal/-/strip-literal-3.1.0.tgz", + "integrity": "sha512-8r3mkIM/2+PpjHoOtiAW8Rg3jJLHaV7xPwG+YRGrv6FP0wwk/toTpATxWYOW0BKdWwl82VT2tFYi5DlROa0Mxg==", + "dev": true, + "license": "MIT", + "dependencies": { + "js-tokens": "^9.0.1" + }, + "funding": { + "url": "https://github.com/sponsors/antfu" + } + }, + "node_modules/strip-literal/node_modules/js-tokens": { + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-9.0.1.tgz", + "integrity": "sha512-mxa9E9ITFOt0ban3j6L5MpjwegGz6lBQmM1IJkWeBZGcMxto50+eWdjC/52xDbS2vy0k7vIMK0Fe2wfL9OQSpQ==", + "dev": true, + "license": "MIT" + }, + "node_modules/strnum": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/strnum/-/strnum-2.2.3.tgz", + "integrity": "sha512-oKx6RUCuHfT3oyVjtnrmn19H1SiCqgJSg+54XqURKp5aCMbrXrhLjRN9TjuwMjiYstZ0MzDrHqkGZ5dFTKd+zg==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT" + }, + "node_modules/supports-color": { + "version": "7.2.0", + "dev": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/tar-stream": { + "version": "1.6.2", + "dev": true, + "license": "MIT", + "dependencies": { + "bl": "^1.0.0", + "buffer-alloc": "^1.2.0", + "end-of-stream": "^1.0.0", + "fs-constants": "^1.0.0", + "readable-stream": "^2.3.0", + "to-buffer": "^1.1.1", + "xtend": "^4.0.0" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/terser": { + "version": "5.46.1", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "@jridgewell/source-map": "^0.3.3", + "acorn": "^8.15.0", + "commander": "^2.20.0", + "source-map-support": "~0.5.20" + }, + "bin": { + "terser": "bin/terser" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/terser/node_modules/commander": { + "version": "2.20.3", + "dev": true, + "license": "MIT" + }, + "node_modules/through": { + "version": "2.3.8", + "dev": true, + "license": "MIT" + }, + "node_modules/tinybench": { + "version": "2.9.0", + "dev": true, + "license": "MIT" + }, + "node_modules/tinyexec": { + "version": "1.1.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, + "node_modules/tinyglobby": { + "version": "0.2.16", + "dev": true, + "license": "MIT", + "dependencies": { + "fdir": "^6.5.0", + "picomatch": "^4.0.4" + }, + "engines": { + "node": ">=12.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/SuperchupuDev" + } + }, + "node_modules/tinypool": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/tinypool/-/tinypool-1.1.1.tgz", + "integrity": "sha512-Zba82s87IFq9A9XmjiX5uZA/ARWDrB03OHlq+Vw1fSdt0I+4/Kutwy8BP4Y/y/aORMo61FQ0vIb5j44vSo5Pkg==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.0.0 || >=20.0.0" + } + }, + "node_modules/tinyrainbow": { + "version": "3.1.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/tinyspy": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/tinyspy/-/tinyspy-4.0.4.tgz", + "integrity": "sha512-azl+t0z7pw/z958Gy9svOTuzqIk6xq+NSheJzn5MMWtWTFywIacg2wUlzKFGtt3cthx0r2SxMK0yzJOR0IES7Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/to-buffer": { + "version": "1.2.2", + "dev": true, + "license": "MIT", + "dependencies": { + "isarray": "^2.0.5", + "safe-buffer": "^5.2.1", + "typed-array-buffer": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/to-buffer/node_modules/isarray": { + "version": "2.0.5", + "dev": true, + "license": "MIT" + }, + "node_modules/toidentifier": { + "version": "1.0.1", + "license": "MIT", + "engines": { + "node": ">=0.6" + } + }, + "node_modules/totalist": { + "version": "3.0.1", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/ts-algebra": { + "version": "2.0.0", + "dev": true, + "license": "MIT" + }, + "node_modules/ts-api-utils": { + "version": "2.5.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18.12" + }, + "peerDependencies": { + "typescript": ">=4.8.4" + } + }, + "node_modules/tslib": { + "version": "2.8.1", + "license": "0BSD" + }, + "node_modules/tsx": { + "version": "4.21.0", + "license": "MIT", + "dependencies": { + "esbuild": "~0.27.0", + "get-tsconfig": "^4.7.5" + }, + "bin": { + "tsx": "dist/cli.mjs" + }, + "engines": { + "node": ">=18.0.0" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + } + }, + "node_modules/tsx/node_modules/fsevents": { + "version": "2.3.3", + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/type-check": { + "version": "0.4.0", + "dev": true, + "license": "MIT", + "dependencies": { + "prelude-ls": "^1.2.1" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/type-is": { + "version": "2.0.1", + "license": "MIT", + "dependencies": { + "content-type": "^1.0.5", + "media-typer": "^1.1.0", + "mime-types": "^3.0.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/typed-array-buffer": { + "version": "1.0.3", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.3", + "es-errors": "^1.3.0", + "is-typed-array": "^1.1.14" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/typescript": { + "version": "5.9.3", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/unbzip2-stream": { + "version": "1.4.3", + "dev": true, + "license": "MIT", + "dependencies": { + "buffer": "^5.2.1", + "through": "^2.3.8" + } + }, + "node_modules/undici-types": { + "version": "6.21.0", + "dev": true, + "license": "MIT" + }, + "node_modules/unpipe": { + "version": "1.0.0", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/uri-js": { + "version": "4.4.1", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/util-deprecate": { + "version": "1.0.2", + "dev": true, + "license": "MIT" + }, + "node_modules/uuid": { + "version": "14.0.0", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-14.0.0.tgz", + "integrity": "sha512-Qo+uWgilfSmAhXCMav1uYFynlQO7fMFiMVZsQqZRMIXp0O7rR7qjkj+cPvBHLgBqi960QCoo/PH2/6ZtVqKvrg==", + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist-node/bin/uuid" + } + }, + "node_modules/vary": { + "version": "1.1.2", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/vite": { + "version": "6.4.2", + "resolved": "https://registry.npmjs.org/vite/-/vite-6.4.2.tgz", + "integrity": "sha512-2N/55r4JDJ4gdrCvGgINMy+HH3iRpNIz8K6SFwVsA+JbQScLiC+clmAxBgwiSPgcG9U15QmvqCGWzMbqda5zGQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "esbuild": "^0.25.0", + "fdir": "^6.4.4", + "picomatch": "^4.0.2", + "postcss": "^8.5.3", + "rollup": "^4.34.9", + "tinyglobby": "^0.2.13" + }, + "bin": { + "vite": "bin/vite.js" + }, + "engines": { + "node": "^18.0.0 || ^20.0.0 || >=22.0.0" + }, + "funding": { + "url": "https://github.com/vitejs/vite?sponsor=1" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + }, + "peerDependencies": { + "@types/node": "^18.0.0 || ^20.0.0 || >=22.0.0", + "jiti": ">=1.21.0", + "less": "*", + "lightningcss": "^1.21.0", + "sass": "*", + "sass-embedded": "*", + "stylus": "*", + "sugarss": "*", + "terser": "^5.16.0", + "tsx": "^4.8.1", + "yaml": "^2.4.2" + }, + "peerDependenciesMeta": { + "@types/node": { + "optional": true + }, + "jiti": { + "optional": true + }, + "less": { + "optional": true + }, + "lightningcss": { + "optional": true + }, + "sass": { + "optional": true + }, + "sass-embedded": { + "optional": true + }, + "stylus": { + "optional": true + }, + "sugarss": { + "optional": true + }, + "terser": { + "optional": true + }, + "tsx": { + "optional": true + }, + "yaml": { + "optional": true + } + } + }, + "node_modules/vite-node": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/vite-node/-/vite-node-3.2.4.tgz", + "integrity": "sha512-EbKSKh+bh1E1IFxeO0pg1n4dvoOTt0UDiXMd/qn++r98+jPO1xtJilvXldeuQ8giIB5IkpjCgMleHMNEsGH6pg==", + "dev": true, + "license": "MIT", + "dependencies": { + "cac": "^6.7.14", + "debug": "^4.4.1", + "es-module-lexer": "^1.7.0", + "pathe": "^2.0.3", + "vite": "^5.0.0 || ^6.0.0 || ^7.0.0-0" + }, + "bin": { + "vite-node": "vite-node.mjs" + }, + "engines": { + "node": "^18.0.0 || ^20.0.0 || >=22.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "node_modules/vite/node_modules/@esbuild/darwin-arm64": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.25.12.tgz", + "integrity": "sha512-N3zl+lxHCifgIlcMUP5016ESkeQjLj/959RxxNYIthIg+CQHInujFuXeWbWMgnTo4cp5XVHqFPmpyu9J65C1Yg==", + "cpu": [ + "arm64" + ], + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=18" + } + }, + "node_modules/vite/node_modules/esbuild": { + "version": "0.25.12", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.12.tgz", + "integrity": "sha512-bbPBYYrtZbkt6Os6FiTLCTFxvq4tt3JKall1vRwshA3fdVztsLAatFaZobhkBC8/BrPetoa0oksYoKXoG4ryJg==", + "dev": true, + "hasInstallScript": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=18" + }, + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.25.12", + "@esbuild/android-arm": "0.25.12", + "@esbuild/android-arm64": "0.25.12", + "@esbuild/android-x64": "0.25.12", + "@esbuild/darwin-arm64": "0.25.12", + "@esbuild/darwin-x64": "0.25.12", + "@esbuild/freebsd-arm64": "0.25.12", + "@esbuild/freebsd-x64": "0.25.12", + "@esbuild/linux-arm": "0.25.12", + "@esbuild/linux-arm64": "0.25.12", + "@esbuild/linux-ia32": "0.25.12", + "@esbuild/linux-loong64": "0.25.12", + "@esbuild/linux-mips64el": "0.25.12", + "@esbuild/linux-ppc64": "0.25.12", + "@esbuild/linux-riscv64": "0.25.12", + "@esbuild/linux-s390x": "0.25.12", + "@esbuild/linux-x64": "0.25.12", + "@esbuild/netbsd-arm64": "0.25.12", + "@esbuild/netbsd-x64": "0.25.12", + "@esbuild/openbsd-arm64": "0.25.12", + "@esbuild/openbsd-x64": "0.25.12", + "@esbuild/openharmony-arm64": "0.25.12", + "@esbuild/sunos-x64": "0.25.12", + "@esbuild/win32-arm64": "0.25.12", + "@esbuild/win32-ia32": "0.25.12", + "@esbuild/win32-x64": "0.25.12" + } + }, + "node_modules/vite/node_modules/fsevents": { + "version": "2.3.3", + "dev": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/vitest": { + "version": "4.1.4", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/expect": "4.1.4", + "@vitest/mocker": "4.1.4", + "@vitest/pretty-format": "4.1.4", + "@vitest/runner": "4.1.4", + "@vitest/snapshot": "4.1.4", + "@vitest/spy": "4.1.4", + "@vitest/utils": "4.1.4", + "es-module-lexer": "^2.0.0", + "expect-type": "^1.3.0", + "magic-string": "^0.30.21", + "obug": "^2.1.1", + "pathe": "^2.0.3", + "picomatch": "^4.0.3", + "std-env": "^4.0.0-rc.1", + "tinybench": "^2.9.0", + "tinyexec": "^1.0.2", + "tinyglobby": "^0.2.15", + "tinyrainbow": "^3.1.0", + "vite": "^6.0.0 || ^7.0.0 || ^8.0.0", + "why-is-node-running": "^2.3.0" + }, + "bin": { + "vitest": "vitest.mjs" + }, + "engines": { + "node": "^20.0.0 || ^22.0.0 || >=24.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@edge-runtime/vm": "*", + "@opentelemetry/api": "^1.9.0", + "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", + "@vitest/browser-playwright": "4.1.4", + "@vitest/browser-preview": "4.1.4", + "@vitest/browser-webdriverio": "4.1.4", + "@vitest/coverage-istanbul": "4.1.4", + "@vitest/coverage-v8": "4.1.4", + "@vitest/ui": "4.1.4", + "happy-dom": "*", + "jsdom": "*", + "vite": "^6.0.0 || ^7.0.0 || ^8.0.0" + }, + "peerDependenciesMeta": { + "@edge-runtime/vm": { + "optional": true + }, + "@opentelemetry/api": { + "optional": true + }, + "@types/node": { + "optional": true + }, + "@vitest/browser-playwright": { + "optional": true + }, + "@vitest/browser-preview": { + "optional": true + }, + "@vitest/browser-webdriverio": { + "optional": true + }, + "@vitest/coverage-istanbul": { + "optional": true + }, + "@vitest/coverage-v8": { + "optional": true + }, + "@vitest/ui": { + "optional": true + }, + "happy-dom": { + "optional": true + }, + "jsdom": { + "optional": true + }, + "vite": { + "optional": false + } + } + }, + "node_modules/vitest/node_modules/es-module-lexer": { + "version": "2.0.0", + "dev": true, + "license": "MIT" + }, + "node_modules/web-streams-polyfill": { + "version": "3.3.3", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/which": { + "version": "2.0.2", + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/which-typed-array": { + "version": "1.1.20", + "dev": true, + "license": "MIT", + "dependencies": { + "available-typed-arrays": "^1.0.7", + "call-bind": "^1.0.8", + "call-bound": "^1.0.4", + "for-each": "^0.3.5", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-tostringtag": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/why-is-node-running": { + "version": "2.3.0", + "dev": true, + "license": "MIT", + "dependencies": { + "siginfo": "^2.0.0", + "stackback": "0.0.2" + }, + "bin": { + "why-is-node-running": "cli.js" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/word-wrap": { + "version": "1.2.5", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/wrappy": { + "version": "1.0.2", + "license": "ISC" + }, + "node_modules/ws": { + "version": "8.20.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/xtend": { + "version": "4.0.2", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=0.4" + } + }, + "node_modules/yaml": { + "version": "2.8.3", + "license": "ISC", + "bin": { + "yaml": "bin.mjs" + }, + "engines": { + "node": ">= 14.6" + }, + "funding": { + "url": "https://github.com/sponsors/eemeli" + } + }, + "node_modules/yauzl": { + "version": "2.10.0", + "dev": true, + "license": "MIT", + "dependencies": { + "buffer-crc32": "~0.2.3", + "fd-slicer": "~1.1.0" + } + }, + "node_modules/yocto-queue": { + "version": "0.1.0", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/zod": { + "version": "4.3.6", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/zod-to-json-schema": { + "version": "3.25.2", + "license": "ISC", + "peer": true, + "peerDependencies": { + "zod": "^3.25.28 || ^4" + } + }, + "strandly": { + "name": "@strands-agents/strandly", + "version": "0.0.1", + "dependencies": { + "commander": "^14", + "tsx": "^4.21.0" + }, + "bin": { + "strandly": "src/cli.ts" + }, + "devDependencies": { + "@types/node": "^22", + "typescript": "^5.5.0" + } + }, + "strands-dev": { + "name": "@strands-agents/dev", + "version": "0.0.1", + "extraneous": true, + "dependencies": { + "commander": "^14", + "tsx": "^4.21.0" + }, + "bin": { + "strands-dev": "src/cli.ts" + }, + "devDependencies": { + "@types/node": "^22", + "typescript": "^5.5.0" + } + }, + "strands-ts": { + "name": "@strands-agents/sdk", + "version": "0.0.1-development", + "license": "Apache-2.0", + "dependencies": { + "@aws-sdk/client-bedrock-runtime": "^3.1037.0", + "@types/json-schema": "^7.0.15", + "uuid": "^14.0.0", + "yaml": "^2.8.3" + }, + "devDependencies": { + "@a2a-js/sdk": "^0.3.10", + "@ai-sdk/amazon-bedrock": "^4.0.77", + "@ai-sdk/openai": "^3.0.41", + "@ai-sdk/provider": "^3.0.0", + "@anthropic-ai/sdk": "^0.92.0", + "@aws-sdk/client-bedrock": "^3.943.0", + "@aws-sdk/client-s3": "^3.943.0", + "@aws-sdk/client-secrets-manager": "^3.943.0", + "@aws-sdk/client-sts": "^3.996.0", + "@aws-sdk/credential-providers": "^3.943.0", + "@aws/bedrock-token-generator": "^1.1.0", + "@eslint/js": "^9.39.4", + "@google/genai": "^1.40.0", + "@opentelemetry/api": "^1.9.0", + "@opentelemetry/exporter-metrics-otlp-http": "^0.214.0", + "@opentelemetry/exporter-trace-otlp-http": "^0.214.0", + "@opentelemetry/resources": "^2.6.1", + "@opentelemetry/sdk-metrics": "^2.6.1", + "@opentelemetry/sdk-trace-base": "^2.6.1", + "@opentelemetry/sdk-trace-node": "^2.6.1", + "@smithy/types": "^4.0.0", + "@types/express": "^5.0.6", + "@types/node": "^25.6.0", + "@types/uuid": "^11.0.0", + "@typescript-eslint/eslint-plugin": "^8.48.1", + "@typescript-eslint/parser": "^8.0.0", + "@vitest/browser": "^4.0.15", + "@vitest/browser-playwright": "^4.0.15", + "@vitest/coverage-v8": "^4.0.15", + "eslint": "^10.2.0", + "eslint-plugin-tsdoc": "^0.5.0", + "express": "^5.2.1", + "openai": "^6.7.0", + "playwright": "^1.60.0", + "tsx": "^4.21.0", + "typescript": "^6.0.2", + "vitest": "^4.0.8" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "@a2a-js/sdk": "^0.3.10", + "@ai-sdk/provider": "^3.0.0", + "@anthropic-ai/sdk": "^0.92.0", + "@aws-sdk/client-s3": "^3.943.0", + "@aws/bedrock-token-generator": "^1.1.0", + "@google/genai": "^1.40.0", + "@modelcontextprotocol/sdk": "^1.25.2", + "@opentelemetry/api": "^1.9.0", + "@opentelemetry/exporter-metrics-otlp-http": "^0.214.0", + "@opentelemetry/exporter-trace-otlp-http": "^0.214.0", + "@opentelemetry/resources": "^2.6.1", + "@opentelemetry/sdk-metrics": "^2.6.1", + "@opentelemetry/sdk-trace-base": "^2.6.1", + "@opentelemetry/sdk-trace-node": "^2.6.1", + "@smithy/types": "^4.0.0", + "express": "^5.1.0", + "openai": "^6.7.0", + "zod": "^4.1.12" + }, + "peerDependenciesMeta": { + "@a2a-js/sdk": { + "optional": true + }, + "@ai-sdk/provider": { + "optional": true + }, + "@anthropic-ai/sdk": { + "optional": true + }, + "@aws-sdk/client-s3": { + "optional": true + }, + "@aws/bedrock-token-generator": { + "optional": true + }, + "@google/genai": { + "optional": true + }, + "@opentelemetry/exporter-metrics-otlp-http": { + "optional": true + }, + "@opentelemetry/exporter-trace-otlp-http": { + "optional": true + }, + "@opentelemetry/resources": { + "optional": true + }, + "@opentelemetry/sdk-metrics": { + "optional": true + }, + "@opentelemetry/sdk-trace-base": { + "optional": true + }, + "@opentelemetry/sdk-trace-node": { + "optional": true + }, + "@smithy/types": { + "optional": true + }, + "express": { + "optional": true + }, + "openai": { + "optional": true + } + } + }, + "strands-ts/node_modules/@opentelemetry/core": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.0.0 <1.10.0" + } + }, + "strands-ts/node_modules/@opentelemetry/resources": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.7.0", + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.3.0 <1.10.0" + } + }, + "strands-ts/node_modules/@opentelemetry/sdk-trace-base": { + "version": "2.7.0", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/core": "2.7.0", + "@opentelemetry/resources": "2.7.0", + "@opentelemetry/semantic-conventions": "^1.29.0" + }, + "engines": { + "node": "^18.19.0 || >=20.6.0" + }, + "peerDependencies": { + "@opentelemetry/api": ">=1.3.0 <1.10.0" + } + }, + "strands-ts/node_modules/@types/node": { + "version": "25.6.0", + "dev": true, + "license": "MIT", + "dependencies": { + "undici-types": "~7.19.0" + } + }, + "strands-ts/node_modules/typescript": { + "version": "6.0.3", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "strands-ts/node_modules/undici-types": { + "version": "7.19.2", + "dev": true, + "license": "MIT" + }, + "strands-wasm": { + "name": "@strands-agents/wasm", + "version": "0.0.1-development", + "dependencies": { + "@aws/bedrock-token-generator": "https://github.com/pgrayy/wasm-deps/releases/download/token-gen-v1.1.0/aws-bedrock-token-generator-1.1.0.tgz", + "@strands-agents/sdk": "*", + "zod": "^4.1.12" + }, + "devDependencies": { + "@bytecodealliance/jco": "^1.16.1", + "@bytecodealliance/preview2-shim": "^0.17.9", + "@chaynabors/componentize-js": "^0.19.3", + "esbuild": "^0.27.4", + "typescript": "^6.0.2", + "vitest": "^3.2.1" + } + }, + "strands-wasm/node_modules/@vitest/expect": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-3.2.4.tgz", + "integrity": "sha512-Io0yyORnB6sikFlt8QW5K7slY4OjqNX9jmJQ02QDda8lyM6B5oNgVWoSoKPac8/kgnCUzuHQKrSLtu/uOqqrig==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/chai": "^5.2.2", + "@vitest/spy": "3.2.4", + "@vitest/utils": "3.2.4", + "chai": "^5.2.0", + "tinyrainbow": "^2.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "strands-wasm/node_modules/@vitest/mocker": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@vitest/mocker/-/mocker-3.2.4.tgz", + "integrity": "sha512-46ryTE9RZO/rfDd7pEqFl7etuyzekzEhUbTW3BvmeO/BcCMEgq59BKhek3dXDWgAj4oMK6OZi+vRr1wPW6qjEQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/spy": "3.2.4", + "estree-walker": "^3.0.3", + "magic-string": "^0.30.17" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "msw": "^2.4.9", + "vite": "^5.0.0 || ^6.0.0 || ^7.0.0-0" + }, + "peerDependenciesMeta": { + "msw": { + "optional": true + }, + "vite": { + "optional": true + } + } + }, + "strands-wasm/node_modules/@vitest/pretty-format": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-3.2.4.tgz", + "integrity": "sha512-IVNZik8IVRJRTr9fxlitMKeJeXFFFN0JaB9PHPGQ8NKQbGpfjlTx9zO4RefN8gp7eqjNy8nyK3NZmBzOPeIxtA==", + "dev": true, + "license": "MIT", + "dependencies": { + "tinyrainbow": "^2.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "strands-wasm/node_modules/@vitest/runner": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-3.2.4.tgz", + "integrity": "sha512-oukfKT9Mk41LreEW09vt45f8wx7DordoWUZMYdY/cyAk7w5TWkTRCNZYF7sX7n2wB7jyGAl74OxgwhPgKaqDMQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/utils": "3.2.4", + "pathe": "^2.0.3", + "strip-literal": "^3.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "strands-wasm/node_modules/@vitest/snapshot": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-3.2.4.tgz", + "integrity": "sha512-dEYtS7qQP2CjU27QBC5oUOxLE/v5eLkGqPE0ZKEIDGMs4vKWe7IjgLOeauHsR0D5YuuycGRO5oSRXnwnmA78fQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "3.2.4", + "magic-string": "^0.30.17", + "pathe": "^2.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "strands-wasm/node_modules/@vitest/spy": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-3.2.4.tgz", + "integrity": "sha512-vAfasCOe6AIK70iP5UD11Ac4siNUNJ9i/9PZ3NKx07sG6sUxeag1LWdNrMWeKKYBLlzuK+Gn65Yd5nyL6ds+nw==", + "dev": true, + "license": "MIT", + "dependencies": { + "tinyspy": "^4.0.3" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "strands-wasm/node_modules/@vitest/utils": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-3.2.4.tgz", + "integrity": "sha512-fB2V0JFrQSMsCo9HiSq3Ezpdv4iYaXRG1Sx8edX3MwxfyNn83mKiGzOcH+Fkxt4MHxr3y42fQi1oeAInqgX2QA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@vitest/pretty-format": "3.2.4", + "loupe": "^3.1.4", + "tinyrainbow": "^2.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + } + }, + "strands-wasm/node_modules/chai": { + "version": "5.3.3", + "resolved": "https://registry.npmjs.org/chai/-/chai-5.3.3.tgz", + "integrity": "sha512-4zNhdJD/iOjSH0A05ea+Ke6MU5mmpQcbQsSOkgdaUMJ9zTlDTD/GYlwohmIE2u0gaxHYiVHEn1Fw9mZ/ktJWgw==", + "dev": true, + "license": "MIT", + "dependencies": { + "assertion-error": "^2.0.1", + "check-error": "^2.1.1", + "deep-eql": "^5.0.1", + "loupe": "^3.1.0", + "pathval": "^2.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "strands-wasm/node_modules/std-env": { + "version": "3.10.0", + "resolved": "https://registry.npmjs.org/std-env/-/std-env-3.10.0.tgz", + "integrity": "sha512-5GS12FdOZNliM5mAOxFRg7Ir0pWz8MdpYm6AY6VPkGpbA7ZzmbzNcBJQ0GPvvyWgcY7QAhCgf9Uy89I03faLkg==", + "dev": true, + "license": "MIT" + }, + "strands-wasm/node_modules/tinyexec": { + "version": "0.3.2", + "resolved": "https://registry.npmjs.org/tinyexec/-/tinyexec-0.3.2.tgz", + "integrity": "sha512-KQQR9yN7R5+OSwaK0XQoj22pwHoTlgYqmUscPYoknOoWCWfj/5/ABTMRi69FrKU5ffPVh5QcFikpWJI/P1ocHA==", + "dev": true, + "license": "MIT" + }, + "strands-wasm/node_modules/tinyrainbow": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/tinyrainbow/-/tinyrainbow-2.0.0.tgz", + "integrity": "sha512-op4nsTR47R6p0vMUUoYl/a+ljLFVtlfaXkLQmqfLR1qHma1h/ysYk4hEXZ880bf2CYgTskvTa/e196Vd5dDQXw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.0.0" + } + }, + "strands-wasm/node_modules/typescript": { + "version": "6.0.3", + "dev": true, + "license": "Apache-2.0", + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "strands-wasm/node_modules/vitest": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-3.2.4.tgz", + "integrity": "sha512-LUCP5ev3GURDysTWiP47wRRUpLKMOfPh+yKTx3kVIEiu5KOMeqzpnYNsKyOoVrULivR8tLcks4+lga33Whn90A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/chai": "^5.2.2", + "@vitest/expect": "3.2.4", + "@vitest/mocker": "3.2.4", + "@vitest/pretty-format": "^3.2.4", + "@vitest/runner": "3.2.4", + "@vitest/snapshot": "3.2.4", + "@vitest/spy": "3.2.4", + "@vitest/utils": "3.2.4", + "chai": "^5.2.0", + "debug": "^4.4.1", + "expect-type": "^1.2.1", + "magic-string": "^0.30.17", + "pathe": "^2.0.3", + "picomatch": "^4.0.2", + "std-env": "^3.9.0", + "tinybench": "^2.9.0", + "tinyexec": "^0.3.2", + "tinyglobby": "^0.2.14", + "tinypool": "^1.1.1", + "tinyrainbow": "^2.0.0", + "vite": "^5.0.0 || ^6.0.0 || ^7.0.0-0", + "vite-node": "3.2.4", + "why-is-node-running": "^2.3.0" + }, + "bin": { + "vitest": "vitest.mjs" + }, + "engines": { + "node": "^18.0.0 || ^20.0.0 || >=22.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@edge-runtime/vm": "*", + "@types/debug": "^4.1.12", + "@types/node": "^18.0.0 || ^20.0.0 || >=22.0.0", + "@vitest/browser": "3.2.4", + "@vitest/ui": "3.2.4", + "happy-dom": "*", + "jsdom": "*" + }, + "peerDependenciesMeta": { + "@edge-runtime/vm": { + "optional": true + }, + "@types/debug": { + "optional": true + }, + "@types/node": { + "optional": true + }, + "@vitest/browser": { + "optional": true + }, + "@vitest/ui": { + "optional": true + }, + "happy-dom": { + "optional": true + }, + "jsdom": { + "optional": true + } + } + } + } +} diff --git a/package.json b/package.json new file mode 100644 index 0000000000..c50c43d928 --- /dev/null +++ b/package.json @@ -0,0 +1,33 @@ +{ + "name": "strands", + "version": "0.0.0", + "private": true, + "workspaces": [ + "strandly", + "strands-ts", + "strands-wasm" + ], + "devDependencies": { + "husky": "^9.1.7", + "prettier": "^3.7.4" + }, + "scripts": { + "dev": "strandly", + "prepare": "husky && npm run build", + "build": "npm run build -w strands-ts", + "test": "npm run test -w strands-ts", + "test:coverage": "npm run test:coverage -w strands-ts", + "test:all": "npm run test:all -w strands-ts", + "test:all:coverage": "npm run test:all:coverage -w strands-ts", + "test:integ": "npm run test:integ -w strands-ts", + "test:integ:all": "npm run test:integ:all -w strands-ts", + "test:browser:install": "npm run test:browser:install -w strands-ts", + "test:package": "npm run test:package -w strands-ts", + "lint": "npm run lint -w strands-ts", + "format": "npm run format -w strands-ts", + "format:check": "npm run format:check -w strands-ts", + "type-check": "npm run type-check -w strands-ts", + "check": "npm run check -w strands-ts", + "check:browser-bundle": "npm run check:browser-bundle -w strands-ts" + } +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..81c1aef125 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,62 @@ +[project] +name = "strands-monorepo-tools" +version = "0.0.0" +description = "Shared Python tooling for the Strands monorepo. Not published." +requires-python = ">=3.10" +dependencies = [ + # Build-time codegen: parses WIT and emits Python bindings. + "componentize-py>=0.23.0,<1.0.0", + # Linter/formatter used by ``strandly check --py`` and ``strandly fmt``. + "ruff>=0.13.0,<0.15.0", + # Type checker used by ``strandly check --py``. + "pyright>=1.1.400", + # Test runner for the shared venv. + "pytest>=9.0.3", + "pytest-asyncio>=1.3.0", + # Optional runtime extras used by strands-py-wasm integration tests. + "pydantic>=2.13.3", + "docstring-parser>=0.18.0", + "boto3>=1.43.2", + "tenacity>=9.1.4", +] + +[tool.setuptools] +# We don't ship a distribution. Suppress the auto-discovery warning. +packages = [] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +cache_dir = ".pytest_cache" + + +# Ruff config lives here because ruff walks up from the file it's linting +# and the monorepo has a single style. strands-py-wasm/pyproject.toml does not +# carry its own ruff table. +[tool.ruff] +line-length = 120 +include = ["strands-py-wasm/src/**/*.py", "strandly/scripts/**/*.py"] +# _generated.py is machine-written; neither lint nor format enforce rules on it. +extend-exclude = ["strands-py-wasm/src/strands/_generated.py"] + +[tool.ruff.lint] +select = [ + "B", # flake8-bugbear + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "UP", # pyupgrade +] + + +# Pyright config lives here for the same reason as ruff: one rule for +# the whole monorepo. Pyright walks up to find the nearest pyproject. +# `exclude` is omitted so pyright inherits its built-in defaults +# (**/node_modules, **/__pycache__, **/.*); setting it would replace +# those, not merge. +[tool.pyright] +include = ["strands-py-wasm/src/strands"] +pythonVersion = "3.10" +pythonPlatform = "All" +typeCheckingMode = "standard" +reportMissingTypeStubs = false +reportMissingImports = "warning" diff --git a/strandly/package.json b/strandly/package.json new file mode 100644 index 0000000000..27baf2160b --- /dev/null +++ b/strandly/package.json @@ -0,0 +1,20 @@ +{ + "name": "@strands-agents/strandly", + "version": "0.0.1", + "private": true, + "type": "module", + "bin": { + "strandly": "./src/cli.ts" + }, + "scripts": { + "type-check": "tsc --noEmit" + }, + "dependencies": { + "commander": "^14", + "tsx": "^4.21.0" + }, + "devDependencies": { + "@types/node": "^22", + "typescript": "^5.5.0" + } +} diff --git a/strandly/src/cli.ts b/strandly/src/cli.ts new file mode 100755 index 0000000000..843886e7de --- /dev/null +++ b/strandly/src/cli.ts @@ -0,0 +1,272 @@ +#!/usr/bin/env tsx + +import { execSync } from 'node:child_process' +import { existsSync, readdirSync, readFileSync, writeFileSync } from 'node:fs' +import { join, resolve } from 'node:path' +import { program } from 'commander' + +const ROOT = resolve(import.meta.dirname, '../..') +const PY = `${ROOT}/strands-py-wasm` +const VENV = `${ROOT}/.venv` + +process.env.PYTHONPYCACHEPREFIX ??= `${ROOT}/.pycache` + +program.name('strandly').description( + `Strands monorepo development CLI + +Build pipeline (each step feeds the next): + wit/agent.wit -> strands-ts -> strands-wasm -> strands-py-wasm + +Most commands accept layer flags (--ts, --wasm, --py). +No flags = run all layers.` +) + +program + .command('setup') + .description('Install toolchains and dependencies') + .option('--node', 'npm install') + .option('--python', 'Create venv and install ruff') + .action((opts) => setup(opts)) + +program + .command('build') + .description('Compile one or more layers') + .option('--ts', 'TypeScript SDK') + .option('--wasm', 'WASM component (rebuilds TS first)') + .option('--py', 'Python package') + .action((opts) => build(opts)) + +program + .command('test') + .description('Run tests') + .option('--py', 'Python tests') + .option('--ts', 'TypeScript tests') + .argument('[file]', 'Specific Python test file') + .action((file, opts) => test({ ...opts, file })) + +program + .command('check') + .description('Lint and type-check without building') + .option('--ts', 'TypeScript type-check') + .option('--wasm', 'WASM bridge type-check') + .option('--py', 'Python ruff') + .action((opts) => check(opts)) + +program + .command('fmt') + .description('Format all code') + .option('--check', 'Fail if anything would change') + .action((opts) => fmt(opts)) + +program + .command('generate') + .description('Regenerate type declarations from WIT') + .option('--check', 'Fail if generated files are out of date') + .action((opts) => generate(opts)) + +program + .command('example') + .description('Run an example by name') + .argument('', 'Example name') + .option('--py', 'Run a Python example') + .option('--ts', 'Run a TypeScript example') + .action((name, opts) => { + if (opts.py) py(`python examples/${name}.py`) + else if (opts.ts) run('npm start', { cwd: `${ROOT}/strands-ts/examples/${name}` }) + }) + +program + .command('clean') + .description('Remove all build artifacts') + .action(() => clean()) + +program + .command('ci') + .description('Full CI pipeline') + .action(() => { + generate({ check: true }) + fmt({ check: true }) + check() + build() + test() + }) + +program + .command('bootstrap') + .description('First-time setup, generate, build, and test') + .action(() => { + setup() + linkCli() + generate() + build() + test() + }) + +program + .command('link') + .description('Install `strandly` on PATH as a live symlink to this repo') + .action(() => linkCli()) + +program + .command('rebuild') + .description('Clean rebuild from scratch') + .action(() => { + clean() + generate() + build() + }) + +const VALIDATE_LAYERS = ['wit', 'ts', 'ts-api', 'wasm', 'py'] as const + +program + .command('validate') + .description('Validate changes to a specific layer') + .argument('', `Layer: ${VALIDATE_LAYERS.join(', ')}`) + .action((layer: string) => { + switch (layer) { + case 'wit': + generate() + build() + test() + break + case 'ts': + build({ ts: true }) + test({ ts: true }) + break + case 'ts-api': + build({ wasm: true }) + test({ ts: true }) + break + case 'wasm': + build({ wasm: true }) + check({ wasm: true }) + break + case 'py': + check({ py: true }) + test({ py: true }) + break + default: + console.error(`Unknown layer: ${layer}\nValid layers: ${VALIDATE_LAYERS.join(', ')}`) + process.exit(1) + } + }) + +program.parse() + +function run(cmd: string, opts?: { cwd?: string; env?: Record }): void { + try { + execSync(cmd, { + stdio: 'inherit', + cwd: opts?.cwd ?? ROOT, + env: opts?.env ? { ...process.env, ...opts.env } : undefined, + }) + } catch (e: unknown) { + const status = (e as { status?: number }).status ?? 1 + console.error(`\nfailed: ${cmd} (exit ${status})`) + process.exit(status) + } +} + +/** Run a command with the repo-root venv on PATH. ``cwd`` defaults to + * strands-py-wasm because most Python commands (pytest, ruff) act on that + * package's source, but callers can override. */ +function py(cmd: string, opts?: { cwd?: string }): void { + run(cmd, { + cwd: opts?.cwd ?? PY, + env: { VIRTUAL_ENV: VENV, PATH: `${VENV}/bin:${process.env.PATH}` }, + }) +} + +function setup(opts?: { node?: boolean; python?: boolean }): void { + const all = !opts?.node && !opts?.python + if (all || opts?.node) run('npm install') + if (all || opts?.python) { + run('python3 -m venv .venv', { cwd: ROOT }) + run(`${VENV}/bin/pip install -e .`, { cwd: ROOT }) + run(`${VENV}/bin/pip install -e strands-py-wasm/`, { cwd: ROOT }) + } +} + +function linkCli(): void { + run('npm link -w strandly') +} + +function build(opts?: { ts?: boolean; wasm?: boolean; py?: boolean }): void { + const all = !opts?.ts && !opts?.wasm && !opts?.py + + if (all || opts?.ts || opts?.wasm) run('npm install') + if (all || opts?.ts) run('npm run build -w strands-ts') + if (all || opts?.wasm) { + if (!all && !opts?.ts) run('npm run build -w strands-ts') + run('npm run build -w strands-wasm') + } +} + +function test(opts?: { py?: boolean; ts?: boolean; file?: string }): void { + const all = !opts?.py && !opts?.ts + if (all || opts?.py) py(opts?.file ? `pytest ${opts.file} -v` : 'pytest') + if (all || opts?.ts) run('npm test -w strands-ts') +} + +function check(opts?: { ts?: boolean; wasm?: boolean; py?: boolean }): void { + const all = !opts?.ts && !opts?.wasm && !opts?.py + if (all || opts?.py) py('ruff check src/strands') + if (all || opts?.ts) run('npm run type-check -w strands-ts') + if (all || opts?.wasm) run('npm run type-check -w strands-wasm') +} + +function fmt(opts?: { check?: boolean }): void { + const flag = opts?.check ? ' --check' : '' + run( + `npx prettier ${opts?.check ? '--check' : '--write'} 'strands-wasm/**/*.ts' 'strands-ts/**/*.ts' --ignore-path .gitignore` + ) + py(`ruff format${flag} src/strands`) +} + +function generate(opts?: { check?: boolean }): void { + run('npm install') + run('npx jco guest-types wit --name strands:agent --world-name agent --out-dir strands-ts/generated', { cwd: ROOT }) + run('npx jco guest-types wit --name strands:agent --world-name agent --out-dir strands-wasm/generated', { cwd: ROOT }) + + // Tag generated TS/WASM type declarations. + for (const dir of ['strands-wasm/generated', 'strands-ts/generated']) { + for (const file of readdirSync(join(ROOT, dir), { recursive: true, encoding: 'utf-8' }).filter((f) => + f.endsWith('.d.ts') + )) { + const path = join(ROOT, dir, file) + const content = readFileSync(path, 'utf-8') + if (!content.startsWith('// @generated')) { + writeFileSync(path, `// @generated from wit/agent.wit -- do not edit\n\n${content}`) + } + } + } + + // Generate Python types from WIT via wasmtime-py's component bindgen. + run(`${VENV}/bin/python -m wasmtime.component.bindgen wit -o strands-py-wasm/src/strands/_generated.py`) + + // Ensure TS + WASM are built first. + if (!existsSync(join(ROOT, 'strands-wasm/dist/strands-agent.wasm'))) { + build({ ts: true, wasm: true }) + } + + if (opts?.check) { + try { + execSync('git diff --quiet -- strands-wasm/generated/ strands-ts/generated/ strands-py-wasm/src/strands/_generated.py', { + cwd: ROOT, + }) + } catch { + console.error("error: generated files are out of date -- run 'strandly generate' and commit") + run('git diff --stat -- strands-wasm/generated/ strands-ts/generated/ strands-py-wasm/src/strands/_generated.py') + process.exit(1) + } + } +} + +function clean(): void { + try { + run('npm run clean --workspaces') + } catch (e) { + console.warn('workspace clean failed (continuing):', (e as Error).message) + } + run('rm -rf .venv strands-py-wasm/target') +} diff --git a/strandly/tsconfig.json b/strandly/tsconfig.json new file mode 100644 index 0000000000..6d68507881 --- /dev/null +++ b/strandly/tsconfig.json @@ -0,0 +1,12 @@ +{ + "compilerOptions": { + "target": "ES2023", + "module": "NodeNext", + "moduleResolution": "NodeNext", + "types": ["node"], + "strict": true, + "noEmit": true, + "skipLibCheck": true + }, + "include": ["src"] +} diff --git a/strands-py-wasm/README.md b/strands-py-wasm/README.md new file mode 100644 index 0000000000..01052f63d1 --- /dev/null +++ b/strands-py-wasm/README.md @@ -0,0 +1,3 @@ +# strands-py-wasm + +Strands Python SDK 2.0 stub. diff --git a/strands-py-wasm/pyproject.toml b/strands-py-wasm/pyproject.toml new file mode 100644 index 0000000000..f2a025cd2b --- /dev/null +++ b/strands-py-wasm/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + + +[project] +name = "strands-agents" +version = "2.0.0a1" +description = "A model-driven approach to building AI agents in just a few lines of code" +requires-python = ">=3.10" +license = {text = "Apache-2.0"} +license-files = ["LICENSE.APACHE", "LICENSE.MIT"] +authors = [ + {name = "AWS", email = "opensource@amazon.com"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dependencies = [ + # Imports as `wasmtime`. Pinned to pgrayy fork until upstream PRs land. + "pgrayy-wasmtime>=46.0.6,<47.0.0", +] + + +[project.optional-dependencies] +pydantic = ["pydantic>=2.4.0,<3.0.0"] + + +[project.urls] +Homepage = "https://github.com/strands-agents/sdk-typescript" +"Bug Tracker" = "https://github.com/strands-agents/sdk-typescript/issues" +Documentation = "https://strandsagents.com" + + +[tool.hatch.build.targets.wheel] +packages = ["src/strands"] + +[tool.hatch.build.targets.wheel.force-include] +# Bundle the wasm into the wheel; _runtime.py finds it via package data. +"../strands-wasm/dist/strands-agent.wasm" = "strands/strands-agent.wasm" diff --git a/strands-py-wasm/src/strands/__init__.py b/strands-py-wasm/src/strands/__init__.py new file mode 100644 index 0000000000..424311a1a8 --- /dev/null +++ b/strands-py-wasm/src/strands/__init__.py @@ -0,0 +1,1137 @@ +"""Strands Agents SDK Python surface. + +Wire types live in :mod:`strands._generated`, auto-generated by ``bindgen``. +Those classes are accepted directly by wasmtime-py at the FFI boundary: +kebab-case record attributes, ``Variant(tag, payload)`` for tagged variants, +raw payloads for untagged ones. + +This module wraps the wire types with ergonomic Python helpers: builders +that take snake_case kwargs, factories that fill variant arms, and SDK-level +event classes that lift wire-shape ``StreamEvent`` values into a typed Python +class hierarchy so users can dispatch with ``isinstance``. +""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import typing +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from dataclasses import asdict, is_dataclass +from typing import Any, Protocol, TypeVar, get_type_hints, runtime_checkable + +from strands import _generated as _t +from strands._generated import * # noqa: F401,F403 re-export every generated type + + +class StrandsError(Exception): + """Base class for all SDK-raised errors.""" + + +class _ModelErrorBase(StrandsError): + """Base for errors surfaced by a model provider.""" + + +class ContextWindowOverflowError(_ModelErrorBase): + """Input exceeded the model's context window and no recovery was possible.""" + + +class MaxTokensError(_ModelErrorBase): + """Model stopped generating because it hit the max-tokens budget.""" + + def __init__(self, message: str, partial_message: _t.Message | None = None) -> None: + super().__init__(message) + self.partial_message = partial_message + + +class ModelThrottledError(_ModelErrorBase): + """Model provider throttled the request. Hook into retry to recover.""" + + +class ProviderTokenCountError(_ModelErrorBase): + """Provider-native token counting failed; base heuristic should run instead.""" + + +class ToolValidationError(StrandsError): + """A tool failed validation at registration or invocation time.""" + + +class JsonValidationError(StrandsError): + """A value could not be serialized to JSON.""" + + +class StructuredOutputError(StrandsError): + """Model refused to use the structured-output tool even after being forced.""" + + +class ConcurrentInvocationError(StrandsError): + """Agent is already processing an invocation; concurrent calls are not allowed.""" + + +class SessionError(StrandsError): + """Session storage read/write failed.""" + + +# Inputs Message / PromptInput accept. Plain strings auto-wrap as text. +_ContentInput = ( + str + | _t.TextBlock + | _t.JsonBlock + | _t.ToolUseBlock + | _t.ToolResultBlock + | _t.ReasoningBlock + | _t.CachePointBlock + | _t.ImageBlock + | _t.VideoBlock + | _t.DocumentBlock + | _t.CitationsBlock + | _t.InterruptResponseBlock + | Any # already-built ContentBlock variants are passed through +) + + +def _as_content_block(item: Any) -> Any: + """Wrap any accepted content shape as a ``ContentBlock`` variant arm.""" + if isinstance(item, str): + return _t.ContentBlock_Text(_t.TextBlock(text=item)) + if isinstance(item, _t.TextBlock): + return _t.ContentBlock_Text(item) + if isinstance(item, _t.JsonBlock): + return _t.ContentBlock_Json(item) + if isinstance(item, _t.ToolUseBlock): + return _t.ContentBlock_ToolUse(item) + if isinstance(item, _t.ToolResultBlock): + return _t.ContentBlock_ToolResult(item) + if isinstance(item, _t.ReasoningBlock): + return _t.ContentBlock_Reasoning(item) + if isinstance(item, _t.CachePointBlock): + return _t.ContentBlock_CachePoint(item) + if isinstance(item, _t.ImageBlock): + return _t.ContentBlock_Image(item) + if isinstance(item, _t.VideoBlock): + return _t.ContentBlock_Video(item) + if isinstance(item, _t.DocumentBlock): + return _t.ContentBlock_Document(item) + if isinstance(item, _t.CitationsBlock): + return _t.ContentBlock_Citations(item) + if isinstance(item, _t.InterruptResponseBlock): + return _t.ContentBlock_InterruptResponse(item) + return item # already a ContentBlock variant + + +class ImageBlock(_t.ImageBlock): + def __init__( + self, + *, + format: str, + bytes: bytes | None = None, + url: str | None = None, + s3: _t.S3Location | None = None, + ) -> None: + provided = [x for x in (bytes, url, s3) if x is not None] + if len(provided) != 1: + raise ValueError("ImageBlock requires exactly one of bytes, url, or s3") + if bytes is not None: + source = _t.ImageSource_Bytes(bytes) + elif url is not None: + source = _t.ImageSource_Url(url) + else: + assert s3 is not None + source = _t.ImageSource_S3(s3) + super().__init__(format=format, source=source) + + +class VideoBlock(_t.VideoBlock): + def __init__( + self, + *, + format: str, + bytes: bytes | None = None, + s3: _t.S3Location | None = None, + ) -> None: + if (bytes is None) == (s3 is None): + raise ValueError("VideoBlock requires exactly one of bytes or s3") + source = _t.VideoSource_Bytes(bytes) if bytes is not None else _t.VideoSource_S3(s3) + super().__init__(format=format, source=source) + + +class DocumentBlock(_t.DocumentBlock): + def __init__( + self, + *, + name: str, + format: str, + bytes: bytes | None = None, + text: str | None = None, + content: list[_t.TextBlock] | None = None, + s3: _t.S3Location | None = None, + citations: bool = False, + context: str | None = None, + ) -> None: + provided = [x for x in (bytes, text, content, s3) if x is not None] + if len(provided) != 1: + raise ValueError("DocumentBlock requires exactly one of bytes, text, content, or s3") + if bytes is not None: + source = _t.DocumentSource_Bytes(bytes) + elif text is not None: + source = _t.DocumentSource_Text(text) + elif content is not None: + source = _t.DocumentSource_Content(content) + else: + assert s3 is not None + source = _t.DocumentSource_S3(s3) + super().__init__( + name=name, + format=format, + source=source, + citations=_t.DocumentCitationsConfig(enabled=citations) if citations else None, + context=context, + ) + + +class InterruptResponseBlock(_t.InterruptResponseBlock): + def __init__(self, *, interrupt_id: str, response: Any) -> None: + payload = response if isinstance(response, str) else json.dumps(response) + super().__init__(interrupt_id=interrupt_id, response=payload) + + +class Message(_t.Message): + def __init__( + self, + *, + role: _t.Role, + content: Iterable[Any], + metadata: _t.MessageMetadata | None = None, + ) -> None: + super().__init__( + role=role, + content=[_as_content_block(c) for c in content], + metadata=metadata, + ) + + @classmethod + def user(cls, *content: Any, metadata: _t.MessageMetadata | None = None) -> Message: + return cls(role=_t.Role.USER, content=content, metadata=metadata) + + @classmethod + def assistant(cls, *content: Any, metadata: _t.MessageMetadata | None = None) -> Message: + return cls(role=_t.Role.ASSISTANT, content=content, metadata=metadata) + + +def _extras_to_json(extras: dict[str, Any] | None) -> str | None: + return json.dumps(extras) if extras else None + + +def BedrockModel( + model_id: str = "us.anthropic.claude-opus-4-7-v1:0", + *, + region: str | None = None, + access_key_id: str | None = None, + secret_access_key: str | None = None, + session_token: str | None = None, + **extras: Any, +) -> Any: + """Build a Bedrock ``model-config`` value.""" + # The wasm bundle links the AWS SDK browser build, which has no credential + # chain. Resolve via botocore so users get the same behavior they'd get + # from any other Python AWS app (env vars, ~/.aws, SSO, IMDS, etc.). + if access_key_id is None and secret_access_key is None: + try: + import botocore.session + + creds = botocore.session.Session().get_credentials() + if creds is not None: + frozen = creds.get_frozen_credentials() + access_key_id = frozen.access_key + secret_access_key = frozen.secret_key + session_token = frozen.token + except ImportError: + pass + + return _t.ModelConfig_Bedrock( + _t.BedrockConfig( + model_id=model_id, + region=region, + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=session_token, + additional_config=_extras_to_json(extras), + ) + ) + + +def AnthropicModel(model_id: str | None = None, *, api_key: str | None = None, **extras: Any) -> Any: + return _t.ModelConfig_Anthropic( + _t.AnthropicConfig(model_id=model_id, api_key=api_key, additional_config=_extras_to_json(extras)) + ) + + +def OpenAIModel(model_id: str | None = None, *, api_key: str | None = None, **extras: Any) -> Any: + return _t.ModelConfig_Openai( + _t.OpenaiConfig(model_id=model_id, api_key=api_key, additional_config=_extras_to_json(extras)) + ) + + +def GoogleModel(model_id: str | None = None, *, api_key: str | None = None, **extras: Any) -> Any: + return _t.ModelConfig_Gemini( + _t.GeminiConfig(model_id=model_id, api_key=api_key, additional_config=_extras_to_json(extras)) + ) + + +def CustomModel( + provider_id: str, + *, + model_id: str | None = None, + stateful: bool = False, + **extras: Any, +) -> Any: + """Host-implemented provider. Pair with a ``model-provider`` callback.""" + return _t.ModelConfig_Custom( + _t.CustomModelConfig( + provider_id=provider_id, + model_id=model_id, + additional_config=_extras_to_json(extras), + stateful=stateful, + ) + ) + + +class PydanticTool: + """Tool whose input schema is derived from a pydantic ``BaseModel``.""" + + def __init__( + self, + *, + name: str, + description: str, + input_model: type, + func: Callable[..., Any], + ) -> None: + if not hasattr(input_model, "model_json_schema") or not hasattr(input_model, "model_validate"): + raise TypeError(f"input_model must be a pydantic BaseModel subclass; got {input_model!r}") + self.name = name + self.description = description + self._input_model = input_model + self.input_schema = input_model.model_json_schema() + self.func = func + + def to_spec(self) -> _t.ToolSpec: + return _t.ToolSpec( + name=self.name, + description=self.description, + input_schema=json.dumps(self.input_schema), + ) + + def invoke(self, raw_input: str) -> list[Any]: + payload = json.loads(raw_input) if raw_input else {} + validated = self._input_model.model_validate(payload) + return _coerce_tool_result(self.func(validated)) + + +class Tool: + """Registered tool: spec plus Python callable.""" + + def __init__( + self, + *, + name: str, + description: str, + input_schema: dict[str, Any], + func: Callable[..., Any], + ) -> None: + self.name = name + self.description = description + self.input_schema = input_schema + self.func = func + + def to_spec(self) -> _t.ToolSpec: + return _t.ToolSpec( + name=self.name, + description=self.description, + input_schema=json.dumps(self.input_schema), + ) + + def invoke(self, raw_input: str) -> list[Any]: + kwargs = json.loads(raw_input) if raw_input else {} + return _coerce_tool_result(self.func(**kwargs)) + + +def _coerce_tool_result(result: Any) -> list[Any]: + if isinstance(result, str): + return [_t.ToolResultContent_Text(_t.TextBlock(text=result))] + if isinstance(result, _t.TextBlock): + return [_t.ToolResultContent_Text(result)] + if isinstance(result, _t.JsonBlock): + return [_t.ToolResultContent_Json(result)] + if isinstance(result, dict): + return [_t.ToolResultContent_Json(_t.JsonBlock(json=json.dumps(result)))] + if is_dataclass(result) and not isinstance(result, type): + return [_t.ToolResultContent_Json(_t.JsonBlock(json=json.dumps(asdict(result))))] + if isinstance(result, list): + return result + return [_t.ToolResultContent_Text(_t.TextBlock(text=str(result)))] + + +def _py_type_to_schema(py_type: Any) -> dict[str, Any]: + import types + + origin = typing.get_origin(py_type) + + # Strip Annotated[T, ...] -- only the runtime type matters for the schema. + if origin is typing.Annotated: + return _py_type_to_schema(typing.get_args(py_type)[0]) + + # Optional[T] / Union[T, None] / T | None: emit T's schema and mark nullable. + if origin is typing.Union or origin is types.UnionType: + args = typing.get_args(py_type) + non_none = [a for a in args if a is not type(None)] + nullable = len(non_none) != len(args) + if len(non_none) == 1: + schema = _py_type_to_schema(non_none[0]) + if nullable: + schema = {**schema, "nullable": True} + return schema + return {} # heterogeneous union -- caller should supply input_schema + + if py_type is str: + return {"type": "string"} + if py_type is int: + return {"type": "integer"} + if py_type is float: + return {"type": "number"} + if py_type is bool: + return {"type": "boolean"} + if origin is list: + args = typing.get_args(py_type) + return {"type": "array", "items": _py_type_to_schema(args[0]) if args else {}} + if origin is dict: + return {"type": "object"} + if origin is typing.Literal: + return {"enum": list(typing.get_args(py_type))} + return {} + + +def tool( + func: Callable[..., Any] | None = None, + *, + name: str | None = None, + description: str | None = None, +) -> Any: + """Decorator that turns a Python function into a :class:`Tool`.""" + + def wrap(f: Callable[..., Any]) -> Tool: + hints = get_type_hints(f) + sig = inspect.signature(f) + properties: dict[str, Any] = {} + required: list[str] = [] + for param_name, param in sig.parameters.items(): + properties[param_name] = _py_type_to_schema(hints.get(param_name, str)) + if param.default is inspect.Parameter.empty: + required.append(param_name) + schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + schema["required"] = required + return Tool( + name=name or f.__name__, + description=description or (f.__doc__ or "").strip() or f.__name__, + input_schema=schema, + func=f, + ) + + return wrap(func) if func is not None else wrap + + +def NullConversationManager() -> Any: + """No management. History grows without bound.""" + return _t.ConversationManagerConfig_None() + + +def SlidingWindowConversationManager( + *, window_size: int = 40, should_truncate_results: bool = True +) -> Any: + return _t.ConversationManagerConfig_SlidingWindow( + _t.SlidingWindowConfig(window_size=window_size, should_truncate_results=should_truncate_results) + ) + + +def SummarizingConversationManager( + *, + summary_ratio: float = 0.3, + preserve_recent_messages: int = 10, + summarization_system_prompt: str | None = None, + summarization_model: Any = None, +) -> Any: + return _t.ConversationManagerConfig_Summarizing( + _t.SummarizingConfig( + summary_ratio=summary_ratio, + preserve_recent_messages=preserve_recent_messages, + summarization_system_prompt=summarization_system_prompt, + summarization_model=summarization_model, + ) + ) + + +def _seconds_to_ns(seconds: float) -> int: + return int(seconds * 1_000_000_000) + + +def _optional_ns(seconds: float | None) -> int | None: + return None if seconds is None else _seconds_to_ns(seconds) + + +def ConstantBackoff(*, delay: float = 1.0) -> Any: + return _t.BackoffStrategy_Constant(_t.ConstantBackoffConfig(delay=_seconds_to_ns(delay))) + + +def LinearBackoff( + *, + base: float = 1.0, + max: float = 30.0, + jitter: _t.JitterKind = _t.JitterKind.FULL, +) -> Any: + return _t.BackoffStrategy_Linear( + _t.LinearBackoffConfig(base=_seconds_to_ns(base), max=_seconds_to_ns(max), jitter=jitter) + ) + + +def ExponentialBackoff( + *, + base: float = 1.0, + max: float = 30.0, + factor: float = 2.0, + jitter: _t.JitterKind = _t.JitterKind.FULL, +) -> Any: + return _t.BackoffStrategy_Exponential( + _t.ExponentialBackoffConfig( + base=_seconds_to_ns(base), + max=_seconds_to_ns(max), + factor=factor, + jitter=jitter, + ) + ) + + +class ModelRetryStrategy(_t.ModelRetryStrategy): + def __init__( + self, + *, + max_attempts: int = 6, + backoff: Any = None, + total_budget: float | None = None, + ) -> None: + super().__init__( + max_attempts=max_attempts, + backoff=backoff if backoff is not None else ExponentialBackoff(), + total_budget=_optional_ns(total_budget), + ) + + +def FileStorage(base_dir: str) -> Any: + return _t.StorageConfig_File(_t.FileStorageConfig(base_dir=base_dir)) + + +def S3Storage(*, bucket: str, region: str | None = None, prefix: str | None = None) -> Any: + return _t.StorageConfig_S3(_t.S3StorageConfig(bucket=bucket, region=region, prefix=prefix)) + + +def CustomStorage(backend_id: str) -> Any: + """Host-implemented backend. Pair with a ``snapshot-storage`` handler.""" + return _t.StorageConfig_Custom(_t.CustomStorageConfig(backend_id=backend_id)) + + +class SessionManager(_t.SessionConfig): + """Attach session persistence to an agent.""" + + def __init__( + self, + *, + session_id: str, + storage: Any, + save_latest: Any = None, + ) -> None: + super().__init__(session_id=session_id, storage=storage, save_latest=save_latest) + + +def _coerce_nested_config(value: Any) -> str: + if isinstance(value, str): + return value + return json.dumps(value, default=_json_default) + + +def _json_default(obj: Any) -> Any: + if is_dataclass(obj) and not isinstance(obj, type): + return asdict(obj) + if hasattr(obj, "__dict__"): + return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")} + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +def AgentNode( + *, + id: str, + agent_config: Any, + description: str | None = None, + timeout: float | None = None, +) -> Any: + return _t.NodeConfig_Agent( + _t.AgentNodeConfig( + id=id, + description=description, + timeout=_optional_ns(timeout), + agent_config=_coerce_nested_config(agent_config), + ) + ) + + +def MultiAgentNode(*, id: str, orchestrator: Any, description: str | None = None) -> Any: + return _t.NodeConfig_MultiAgent( + _t.MultiAgentNodeConfig( + id=id, + description=description, + orchestrator=_coerce_nested_config(orchestrator), + ) + ) + + +class Graph(_t.GraphConfig): + def __init__( + self, + *, + id: str, + nodes: list[Any], + edges: list[Any] | None = None, + sources: list[str] | None = None, + max_concurrency: int | None = None, + max_steps: int | None = None, + timeout: float | None = None, + node_timeout: float | None = None, + ) -> None: + super().__init__( + id=id, + nodes=nodes, + edges=edges or [], + sources=sources or [], + max_concurrency=max_concurrency, + max_steps=max_steps, + timeout=_optional_ns(timeout), + node_timeout=_optional_ns(node_timeout), + ) + + +class Swarm(_t.SwarmConfig): + def __init__( + self, + *, + id: str, + nodes: list[Any], + start_node_id: str, + max_steps: int | None = None, + timeout: float | None = None, + node_timeout: float | None = None, + ) -> None: + super().__init__( + id=id, + nodes=nodes, + start_node_id=start_node_id, + max_steps=max_steps, + timeout=_optional_ns(timeout), + node_timeout=_optional_ns(node_timeout), + ) + + +def BashTool(*, default_timeout: int | None = None) -> Any: + return _t.VendedTool_Bash(_t.BashToolConfig(default_timeout_s=default_timeout)) + + +def FileEditorTool(*, workspace_root: str | None = None) -> Any: + return _t.VendedTool_FileEditor(_t.FileEditorToolConfig(workspace_root=workspace_root)) + + +def HttpRequestTool(*, allowed_hosts: list[str] | None = None, max_response_bytes: int = 0) -> Any: + return _t.VendedTool_HttpRequest( + _t.HttpRequestToolConfig( + allowed_hosts=allowed_hosts or [], + max_response_bytes=max_response_bytes, + ) + ) + + +def NotebookTool(*, workspace_root: str | None = None) -> Any: + return _t.VendedTool_Notebook(_t.NotebookToolConfig(workspace_root=workspace_root)) + + +def SkillsPlugin( + *, + skills: list[str], + strict: bool = False, + max_resource_files: int | None = None, + state_key: str | None = None, +) -> Any: + return _t.VendedPlugin_Skills( + _t.SkillsPluginConfig( + skills=[_t.SkillSource(path=p) for p in skills], + strict=strict, + max_resource_files=max_resource_files, + state_key=state_key, + ) + ) + + +def ContextOffloaderPlugin( + *, + max_result_tokens: int | None = None, + preview_tokens: int | None = None, + include_retrieval_tool: bool = True, +) -> Any: + return _t.VendedPlugin_ContextOffloader( + _t.ContextOffloaderPluginConfig( + max_result_tokens=max_result_tokens, + preview_tokens=preview_tokens, + include_retrieval_tool=include_retrieval_tool, + ) + ) + + +def StdioMcpTransport( + *, + command: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + cwd: str | None = None, +) -> Any: + """Launch an MCP server as a subprocess and talk to it over stdio.""" + return _t.McpTransport_Stdio( + _t.StdioTransportConfig( + command=command, + args=args or [], + env=[_t.EnvVar(key=k, value=v) for k, v in (env or {}).items()], + cwd=cwd, + ) + ) + + +def StreamableHttpMcpTransport(*, url: str, headers: dict[str, str] | None = None) -> Any: + """Talk to a hosted MCP server over streamable HTTP.""" + return _t.McpTransport_StreamableHttp( + _t.HttpTransportConfig( + url=url, + headers=[_t.HttpHeader(name=k, value=v) for k, v in (headers or {}).items()], + ) + ) + + +def SseMcpTransport(*, url: str, headers: dict[str, str] | None = None) -> Any: + """Legacy SSE transport.""" + return _t.McpTransport_Sse( + _t.SseTransportConfig( + url=url, + headers=[_t.HttpHeader(name=k, value=v) for k, v in (headers or {}).items()], + ) + ) + + +class McpClient(_t.McpClientConfig): + """Declare an MCP client the host should open and route tools from.""" + + def __init__( + self, + *, + client_id: str, + transport: Any, + application_name: str | None = None, + application_version: str | None = None, + tasks_ttl: float | None = None, + tasks_poll_timeout: float | None = None, + elicitation_enabled: bool = False, + fail_open: bool = False, + disable_instrumentation: bool = False, + ) -> None: + tasks = None + if tasks_ttl is not None or tasks_poll_timeout is not None: + tasks = _t.TasksConfig( + ttl=_seconds_to_ns(tasks_ttl if tasks_ttl is not None else 60.0), + poll_timeout=_seconds_to_ns(tasks_poll_timeout if tasks_poll_timeout is not None else 300.0), + ) + super().__init__( + client_id=client_id, + application_name=application_name, + application_version=application_version, + transport=transport, + tasks_config=tasks, + elicitation_enabled=elicitation_enabled, + fail_open=fail_open, + disable_instrumentation=disable_instrumentation, + ) + + +class InterruptResponse(_t.RespondArgs): + """Reply to a paused agent via ``response-stream.respond``.""" + + def __init__(self, *, interrupt_id: str, response: Any) -> None: + payload = response if isinstance(response, str) else json.dumps(response) + super().__init__(interrupt_id=interrupt_id, response=payload) + + + +_ToolInput = Tool | PydanticTool | Callable[..., Any] +_ToolChoiceInput = Any # tagged ToolChoice variant or "name" shorthand or None + + +def _coerce_tool(item: _ToolInput) -> Tool | PydanticTool: + if isinstance(item, (Tool, PydanticTool)): + return item + if callable(item): + return tool(item) + raise TypeError(f"unsupported tool: {type(item).__name__}") + + +def _coerce_prompt(value: Any) -> Any: + """Coerce a string, content blocks, or pre-built value to a ``prompt-input``.""" + if isinstance(value, str): + return value + if isinstance(value, list): + return [_as_content_block(c) for c in value] + if hasattr(value, "__iter__") and not isinstance(value, (bytes, str)): + return [_as_content_block(c) for c in value] + return value + + +def _coerce_tool_choice(value: _ToolChoiceInput) -> Any: + if value is None: + return None + if isinstance(value, str): + return _t.ToolChoice_Named(value) + return value + + +class Agent: + """Strands agent. Construct once; call :meth:`invoke` or :meth:`stream_async`.""" + + def __init__( + self, + *, + model: Any = None, + messages: list[Any] | None = None, + system_prompt: Any = None, + tools: list[_ToolInput] | None = None, + agent_tools: list[Any] | None = None, + vended_tools: list[Any] | None = None, + vended_plugins: list[Any] | None = None, + mcp_clients: list[Any] | None = None, + name: str | None = None, + id: str | None = None, + description: str | None = None, + tool_executor: Any = None, + display_output: bool | None = None, + trace_attributes: list[Any] | None = None, + trace_context: Any = None, + session: Any = None, + conversation_manager: Any = None, + retry: Any = None, + structured_output_schema: str | None = None, + app_state: dict[str, Any] | None = None, + model_state: dict[str, Any] | None = None, + ) -> None: + self._tools: list[Tool | PydanticTool] = [_coerce_tool(t) for t in (tools or [])] + identity = None + if name is not None or id is not None or description is not None: + identity = _t.AgentIdentity(name=name, id=id, description=description) + + self._config = _t.AgentConfig( + model=model, + model_params=None, + messages=messages, + system_prompt=(_coerce_prompt(system_prompt) if system_prompt is not None else None), + tools=[t.to_spec() for t in self._tools] or None, + agent_tools=agent_tools, + vended_tools=vended_tools, + vended_plugins=vended_plugins, + mcp_clients=mcp_clients, + identity=identity, + tool_executor=tool_executor, + display_output=display_output, + trace_attributes=trace_attributes, + trace_context=trace_context, + session=session, + conversation_manager=conversation_manager, + retry=retry, + structured_output_schema=structured_output_schema, + app_state=json.dumps(app_state) if app_state else None, + model_state=json.dumps(model_state) if model_state else None, + ) + self._runtime: Any = None + + @property + def config(self) -> _t.AgentConfig: + return self._config + + def _ensure_runtime(self) -> Any: + if self._runtime is None: + from ._runtime import _AgentRuntime + + self._runtime = _AgentRuntime(self) + return self._runtime + + async def _ensure_runtime_async(self) -> Any: + rt = self._ensure_runtime() + await rt.async_init() + return rt + + def _lookup_tool(self, name: str) -> Tool | PydanticTool: + for t in self._tools: + if getattr(t, "name", None) == name: + return t + raise KeyError(f"no tool registered under name {name!r}") + + def _build_invoke_args( + self, + prompt: Any, + tools: list[_ToolInput] | None, + tool_choice: _ToolChoiceInput, + structured_output_schema: str | None, + ) -> _t.InvokeArgs: + extra_tools = [_coerce_tool(t).to_spec() for t in (tools or [])] or None + return _t.InvokeArgs( + input=_coerce_prompt(prompt), + tools=extra_tools, + tool_choice=_coerce_tool_choice(tool_choice), + structured_output_schema=structured_output_schema, + ) + + async def stream_async( + self, + prompt: Any, + *, + tools: list[_ToolInput] | None = None, + tool_choice: _ToolChoiceInput = None, + structured_output_schema: str | None = None, + ) -> AsyncIterator[_t.StreamEvent]: + """Yield :class:`StreamEvent` arms as the agent runs.""" + runtime = await self._ensure_runtime_async() + args = self._build_invoke_args(prompt, tools, tool_choice, structured_output_schema) + stream = await runtime.generate(args) + async for event in stream: + yield event + + async def invoke_async( + self, + prompt: Any, + *, + tools: list[_ToolInput] | None = None, + tool_choice: _ToolChoiceInput = None, + structured_output_schema: str | None = None, + ) -> AgentResult: + """Run the agent to completion and return an :class:`AgentResult`.""" + accumulator = _AgentResultAccumulator() + async for event in self.stream_async( + prompt, + tools=tools, + tool_choice=tool_choice, + structured_output_schema=structured_output_schema, + ): + accumulator.consume(event) + return accumulator.finalize(self) + + def invoke( + self, + prompt: Any, + *, + tools: list[_ToolInput] | None = None, + tool_choice: _ToolChoiceInput = None, + structured_output_schema: str | None = None, + ) -> AgentResult: + """Synchronous wrapper around :meth:`invoke_async`. + + Raises :class:`RuntimeError` if called from a running event loop. Use + :meth:`invoke_async` directly in Jupyter or async frameworks. + """ + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + raise RuntimeError( + "Agent.invoke() cannot run inside an existing event loop. " + "Use 'await agent.invoke_async(...)' instead." + ) + return asyncio.run( + self.invoke_async( + prompt, + tools=tools, + tool_choice=tool_choice, + structured_output_schema=structured_output_schema, + ) + ) + + def cancel(self) -> None: + """Cancel the in-flight invocation. Fire-and-forget.""" + if self._runtime is not None: + self._runtime.cancel() + + async def respond(self, interrupt_id: str, response: Any) -> None: + runtime = await self._ensure_runtime_async() + payload = response if isinstance(response, str) else json.dumps(response) + await runtime.respond(_t.RespondArgs(interrupt_id=interrupt_id, response=payload)) + + async def get_messages(self) -> list[_t.Message]: + return await (await self._ensure_runtime_async()).get_messages() + + async def set_messages(self, messages: list[_t.Message]) -> None: + await (await self._ensure_runtime_async()).set_messages(messages) + + async def get_app_state(self) -> dict[str, Any]: + return await (await self._ensure_runtime_async()).get_app_state() + + async def set_app_state(self, state: dict[str, Any]) -> None: + await (await self._ensure_runtime_async()).set_app_state(state) + + async def get_model_state(self) -> dict[str, Any]: + return await (await self._ensure_runtime_async()).get_model_state() + + async def set_model_state(self, state: dict[str, Any]) -> None: + await (await self._ensure_runtime_async()).set_model_state(state) + + +class _AgentResultAccumulator: + """Folds the stream of events into the fields of an :class:`AgentResult`.""" + + def __init__(self) -> None: + self._stop: _t.StopEvent | None = None + self._last_message: _t.Message | None = None + self._interrupts: list[_t.Interrupt] = [] + + def consume(self, event: _t.StreamEvent) -> None: + if isinstance(event, _t.StreamEvent_MessageAdded): + self._last_message = event.value.message + elif isinstance(event, _t.StreamEvent_ModelMessage): + self._last_message = event.value.message + elif isinstance(event, _t.StreamEvent_Stop): + self._stop = event.value + elif isinstance(event, _t.StreamEvent_AgentResult): + self._stop = event.value.stop + elif isinstance(event, _t.StreamEvent_Interrupt): + self._interrupts.append(event.value) + + def finalize(self, agent: Agent) -> AgentResult: + stop = self._stop + last = self._last_message + if last is None: + last = _t.Message(role=_t.Role.ASSISTANT, content=[], metadata=None) + return AgentResult( + stop_reason=stop.reason if stop is not None else _t.StopReason.END_TURN, + last_message=last, + usage=stop.usage if stop is not None else None, + metrics=None, + structured_output=(json.loads(stop.structured_output) if stop and stop.structured_output else None), + interrupts=self._interrupts or None, + ) + + +_HookEventT = TypeVar("_HookEventT") +_HookCallback = Callable[[Any], Any] + + +@runtime_checkable +class HookProvider(Protocol): + """Bundle of related hook registrations.""" + + def register_hooks(self, registry: HookRegistry) -> None: ... + + +class HookRegistry: + """Register callbacks keyed by ``StreamEvent`` arm class. + + Each arm of the wire ``stream-event`` variant is a distinct Python class + (``StreamEvent_TextDelta``, ``StreamEvent_Stop``, ...). Subscribers match + by exact class. + + Callbacks for arms whose name begins with ``After`` dispatch in reverse + registration order, mirroring teardown semantics. Everything else + dispatches FIFO. + """ + + def __init__(self) -> None: + self._callbacks: dict[type, list[_HookCallback]] = {} + + def add_callback( + self, + event_type: type[_HookEventT], + callback: Callable[[_HookEventT], Any], + ) -> Callable[[], None]: + entries = self._callbacks.setdefault(event_type, []) + entry = typing.cast(_HookCallback, callback) + entries.append(entry) + + def _remove() -> None: + try: + self._callbacks[event_type].remove(entry) + except (KeyError, ValueError): + pass + + return _remove + + def add_hook(self, provider: HookProvider) -> None: + provider.register_hooks(self) + + def dispatch(self, event: Any) -> None: + callbacks = self._callbacks_for(event) + if any(inspect.iscoroutinefunction(cb) for cb in callbacks): + raise RuntimeError(f"event={type(event).__name__} | use dispatch_async for async callbacks") + for cb in callbacks: + cb(event) + + async def dispatch_async(self, event: Any) -> None: + for cb in self._callbacks_for(event): + result = cb(event) + if inspect.iscoroutine(result): + await typing.cast(Awaitable[Any], result) + + def _callbacks_for(self, event: Any) -> list[_HookCallback]: + entries = self._callbacks.get(type(event), []) + return list(reversed(entries)) if type(event).__name__.startswith("After") else list(entries) + + +class AgentResult: + """Terminal result of an agent invocation.""" + + def __init__( + self, + *, + stop_reason: _t.StopReason, + last_message: _t.Message, + invocation_state: dict[str, Any] | None = None, + traces: list[_t.AgentTrace] | None = None, + metrics: _t.AgentMetrics | None = None, + usage: _t.Usage | None = None, + structured_output: Any = None, + interrupts: list[_t.Interrupt] | None = None, + ) -> None: + self.stop_reason = stop_reason + self.last_message = last_message + self.invocation_state = invocation_state if invocation_state is not None else {} + self.traces = traces + self.metrics = metrics + self.usage = usage + self.structured_output = structured_output + self.interrupts = interrupts + + @property + def context_size(self) -> int | None: + return self.metrics.latest_context_size if self.metrics else None + + @property + def projected_context_size(self) -> int | None: + return self.metrics.projected_context_size if self.metrics else None + + def __str__(self) -> str: + """Concatenate text from TextBlock and ReasoningBlock content.""" + chunks: list[str] = [] + for block in self.last_message.content: + tag = getattr(block, "tag", None) + payload = getattr(block, "payload", None) + if tag == "text" and payload is not None: + chunks.append(payload.text) + elif tag == "reasoning" and payload is not None and payload.text: + chunks.append(payload.text) + return "\n".join(chunks) diff --git a/strands-py-wasm/src/strands/_generated.py b/strands-py-wasm/src/strands/_generated.py new file mode 100644 index 0000000000..624367c77b --- /dev/null +++ b/strands-py-wasm/src/strands/_generated.py @@ -0,0 +1,5438 @@ +"""Auto-generated by bindgen. Do not edit. + +Wire-shape Python types for the WIT world. Constructed values are accepted +directly by wasmtime-py without further marshaling — kebab-case record +attributes, ``Variant(tag, payload)`` for tagged variants, raw payloads for +untagged ones. +""" + +from __future__ import annotations + +from typing import Any, Optional, Union + +from wasmtime.component import Variant as _WitVariant +from wasmtime.component import VariantCase as _WitVariantCase + + +def ok(value: Any = None) -> _WitVariant: + """Wrap ``value`` as the ``ok`` arm of a ``result``.""" + return _WitVariant("ok", value) + + +def err(value: Any = None) -> _WitVariant: + """Wrap ``value`` as the ``err`` arm of a ``result``.""" + return _WitVariant("err", value) + + +class Error: + """A resource which represents some error information. + +The only method provided by this resource is `to-debug-string`, +which provides some human-readable information about the error. + +In the `wasi:io` package, this resource is returned through the +`wasi:io/streams/stream-error` type. + +To provide more specific error information, other interfaces may +offer functions to "downcast" this error into more specific types. For example, +errors returned from streams derived from filesystem types can be described using +the filesystem's own error-code type. This is done using the function +`wasi:filesystem/types/filesystem-error-code`, which takes a `borrow` +parameter and returns an `option`. + +The set of functions which can "downcast" an `error` into a more +concrete type is open.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def to_debug_string(self) -> str: + return self._invoke('[method]error.to-debug-string', (self._handle,)) + + +class Pollable: + """`pollable` represents a single I/O event which may be ready, or not.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def ready(self) -> bool: + return self._invoke('[method]pollable.ready', (self._handle,)) + + def block(self) -> None: + return self._invoke('[method]pollable.block', (self._handle,)) + + +StreamError = Any | None +"""An error for input-stream and output-stream operations.""" + +class InputStream: + """An input bytestream. + +`input-stream`s are *non-blocking* to the extent practical on underlying +platforms. I/O operations always return promptly; if fewer bytes are +promptly available than requested, they return the number of bytes promptly +available, which could even be zero. To wait for data to be available, +use the `subscribe` function to obtain a `pollable` which can be polled +for using `wasi:io/poll`.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def read(self, len: int) -> Any: + return self._invoke('[method]input-stream.read', (self._handle, len,)) + + def blocking_read(self, len: int) -> Any: + return self._invoke('[method]input-stream.blocking-read', (self._handle, len,)) + + def skip(self, len: int) -> Any: + return self._invoke('[method]input-stream.skip', (self._handle, len,)) + + def blocking_skip(self, len: int) -> Any: + return self._invoke('[method]input-stream.blocking-skip', (self._handle, len,)) + + def subscribe(self) -> Any: + return self._invoke('[method]input-stream.subscribe', (self._handle,)) + + +class OutputStream: + """An output bytestream. + +`output-stream`s are *non-blocking* to the extent practical on +underlying platforms. Except where specified otherwise, I/O operations also +always return promptly, after the number of bytes that can be written +promptly, which could even be zero. To wait for the stream to be ready to +accept data, the `subscribe` function to obtain a `pollable` which can be +polled for using `wasi:io/poll`. + +Dropping an `output-stream` while there's still an active write in +progress may result in the data being lost. Before dropping the stream, +be sure to fully flush your writes.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def check_write(self) -> Any: + return self._invoke('[method]output-stream.check-write', (self._handle,)) + + def write(self, contents: bytes) -> Any: + return self._invoke('[method]output-stream.write', (self._handle, contents,)) + + def blocking_write_and_flush(self, contents: bytes) -> Any: + return self._invoke('[method]output-stream.blocking-write-and-flush', (self._handle, contents,)) + + def flush(self) -> Any: + return self._invoke('[method]output-stream.flush', (self._handle,)) + + def blocking_flush(self) -> Any: + return self._invoke('[method]output-stream.blocking-flush', (self._handle,)) + + def subscribe(self) -> Any: + return self._invoke('[method]output-stream.subscribe', (self._handle,)) + + def write_zeroes(self, len: int) -> Any: + return self._invoke('[method]output-stream.write-zeroes', (self._handle, len,)) + + def blocking_write_zeroes_and_flush(self, len: int) -> Any: + return self._invoke('[method]output-stream.blocking-write-zeroes-and-flush', (self._handle, len,)) + + def splice(self, src: Any, len: int) -> Any: + return self._invoke('[method]output-stream.splice', (self._handle, src, len,)) + + def blocking_splice(self, src: Any, len: int) -> Any: + return self._invoke('[method]output-stream.blocking-splice', (self._handle, src, len,)) + + +class Datetime: + """A time and date in seconds plus nanoseconds.""" + def __init__( + self, + *, + seconds: int, + nanoseconds: int, + ) -> None: + setattr(self, 'seconds', seconds) + setattr(self, 'nanoseconds', nanoseconds) + + def __repr__(self) -> str: + return f'Datetime(seconds={getattr(self, 'seconds')!r}, nanoseconds={getattr(self, 'nanoseconds')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Datetime): + return NotImplemented + return getattr(self, 'seconds') == getattr(other, 'seconds') and getattr(self, 'nanoseconds') == getattr(other, 'nanoseconds') + + def __hash__(self) -> int: + return id(self) + + +class LogLevel(str): + """Severity level of a log entry.""" + __slots__ = () + + TRACE: 'LogLevel' + DEBUG: 'LogLevel' + INFO: 'LogLevel' + WARN: 'LogLevel' + ERROR: 'LogLevel' + +LogLevel.TRACE = LogLevel('trace') # type: ignore[attr-defined] +LogLevel.DEBUG = LogLevel('debug') # type: ignore[attr-defined] +LogLevel.INFO = LogLevel('info') # type: ignore[attr-defined] +LogLevel.WARN = LogLevel('warn') # type: ignore[attr-defined] +LogLevel.ERROR = LogLevel('error') # type: ignore[attr-defined] + + +class LogEntry: + """A single structured log entry.""" + def __init__( + self, + *, + level: LogLevel, + message: str, + context: Optional[str], + ) -> None: + setattr(self, 'level', level) + setattr(self, 'message', message) + setattr(self, 'context', context) + + def __repr__(self) -> str: + return f'LogEntry(level={getattr(self, 'level')!r}, message={getattr(self, 'message')!r}, context={getattr(self, 'context')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, LogEntry): + return NotImplemented + return getattr(self, 'level') == getattr(other, 'level') and getattr(self, 'message') == getattr(other, 'message') and getattr(self, 'context') == getattr(other, 'context') + + def __hash__(self) -> int: + return id(self) + + +class ElicitRequest: + """Request for user input.""" + def __init__( + self, + *, + client_id: str, + message: str, + request: str, + ) -> None: + setattr(self, 'client-id', client_id) + setattr(self, 'message', message) + setattr(self, 'request', request) + + @property + def client_id(self) -> str: + return getattr(self, 'client-id') + + def __repr__(self) -> str: + return f'ElicitRequest(client_id={getattr(self, 'client-id')!r}, message={getattr(self, 'message')!r}, request={getattr(self, 'request')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ElicitRequest): + return NotImplemented + return getattr(self, 'client-id') == getattr(other, 'client-id') and getattr(self, 'message') == getattr(other, 'message') and getattr(self, 'request') == getattr(other, 'request') + + def __hash__(self) -> int: + return id(self) + + +class ElicitAction(str): + """Outcome of an elicitation request.""" + __slots__ = () + + ACCEPT: 'ElicitAction' + DECLINE: 'ElicitAction' + CANCEL: 'ElicitAction' + +ElicitAction.ACCEPT = ElicitAction('accept') # type: ignore[attr-defined] +ElicitAction.DECLINE = ElicitAction('decline') # type: ignore[attr-defined] +ElicitAction.CANCEL = ElicitAction('cancel') # type: ignore[attr-defined] + + +class ElicitResponse: + """Response to an elicitation request.""" + def __init__( + self, + *, + action: ElicitAction, + content: Optional[str], + ) -> None: + setattr(self, 'action', action) + setattr(self, 'content', content) + + def __repr__(self) -> str: + return f'ElicitResponse(action={getattr(self, 'action')!r}, content={getattr(self, 'content')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ElicitResponse): + return NotImplemented + return getattr(self, 'action') == getattr(other, 'action') and getattr(self, 'content') == getattr(other, 'content') + + def __hash__(self) -> int: + return id(self) + + +class ElicitationError: + """Why an elicitation call failed.""" + pass + +class ElicitationError_UnknownClient(ElicitationError, _WitVariantCase): + """No handler registered for the given `client-id`.""" + tag = 'unknown-client' + +class ElicitationError_HandlerFailed(ElicitationError, _WitVariantCase): + """Handler raised an exception.""" + tag = 'handler-failed' + +class ElicitationError_TimedOut(ElicitationError, _WitVariantCase): + """Request timed out waiting for a human response.""" + tag = 'timed-out' + +_ElicitationError_CASES: dict[str, type] = { + 'unknown-client': ElicitationError_UnknownClient, + 'handler-failed': ElicitationError_HandlerFailed, + 'timed-out': ElicitationError_TimedOut, +} + +def _ElicitationError_lift(raw: _WitVariant) -> ElicitationError: + cls = _ElicitationError_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ElicitationError arm: {raw.tag!r}') + return cls(raw.payload) +ElicitationError.lift = staticmethod(_ElicitationError_lift) # type: ignore[attr-defined] + +class TextBlock: + """Plain text.""" + def __init__( + self, + *, + text: str, + ) -> None: + setattr(self, 'text', text) + + def __repr__(self) -> str: + return f'TextBlock(text={getattr(self, 'text')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TextBlock): + return NotImplemented + return getattr(self, 'text') == getattr(other, 'text') + + def __hash__(self) -> int: + return id(self) + + +class S3Location: + """Object stored in Amazon S3.""" + def __init__( + self, + *, + uri: str, + bucket_owner: Optional[str], + ) -> None: + setattr(self, 'uri', uri) + setattr(self, 'bucket-owner', bucket_owner) + + @property + def bucket_owner(self) -> Optional[str]: + return getattr(self, 'bucket-owner') + + def __repr__(self) -> str: + return f'S3Location(uri={getattr(self, 'uri')!r}, bucket_owner={getattr(self, 'bucket-owner')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, S3Location): + return NotImplemented + return getattr(self, 'uri') == getattr(other, 'uri') and getattr(self, 'bucket-owner') == getattr(other, 'bucket-owner') + + def __hash__(self) -> int: + return id(self) + + +ImageSource = bytes | str | S3Location +"""Source of image bytes.""" + +class ImageBlock: + """Image attached to a message.""" + def __init__( + self, + *, + format: str, + source: ImageSource, + ) -> None: + setattr(self, 'format', format) + setattr(self, 'source', source) + + def __repr__(self) -> str: + return f'ImageBlock(format={getattr(self, 'format')!r}, source={getattr(self, 'source')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ImageBlock): + return NotImplemented + return getattr(self, 'format') == getattr(other, 'format') and getattr(self, 'source') == getattr(other, 'source') + + def __hash__(self) -> int: + return id(self) + + +VideoSource = bytes | S3Location +"""Source of video bytes.""" + +class VideoBlock: + """Video attached to a message.""" + def __init__( + self, + *, + format: str, + source: VideoSource, + ) -> None: + setattr(self, 'format', format) + setattr(self, 'source', source) + + def __repr__(self) -> str: + return f'VideoBlock(format={getattr(self, 'format')!r}, source={getattr(self, 'source')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, VideoBlock): + return NotImplemented + return getattr(self, 'format') == getattr(other, 'format') and getattr(self, 'source') == getattr(other, 'source') + + def __hash__(self) -> int: + return id(self) + + +DocumentSource = bytes | str | list[TextBlock] | S3Location +"""Source of document bytes.""" + +class DocumentCitationsConfig: + """Citation configuration attached to a document.""" + def __init__( + self, + *, + enabled: bool, + ) -> None: + setattr(self, 'enabled', enabled) + + def __repr__(self) -> str: + return f'DocumentCitationsConfig(enabled={getattr(self, 'enabled')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DocumentCitationsConfig): + return NotImplemented + return getattr(self, 'enabled') == getattr(other, 'enabled') + + def __hash__(self) -> int: + return id(self) + + +class DocumentBlock: + """Document attached to a message.""" + def __init__( + self, + *, + name: str, + format: str, + source: DocumentSource, + citations: Optional[DocumentCitationsConfig], + context: Optional[str], + ) -> None: + setattr(self, 'name', name) + setattr(self, 'format', format) + setattr(self, 'source', source) + setattr(self, 'citations', citations) + setattr(self, 'context', context) + + def __repr__(self) -> str: + return f'DocumentBlock(name={getattr(self, 'name')!r}, format={getattr(self, 'format')!r}, source={getattr(self, 'source')!r}, citations={getattr(self, 'citations')!r}, context={getattr(self, 'context')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DocumentBlock): + return NotImplemented + return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'format') == getattr(other, 'format') and getattr(self, 'source') == getattr(other, 'source') and getattr(self, 'citations') == getattr(other, 'citations') and getattr(self, 'context') == getattr(other, 'context') + + def __hash__(self) -> int: + return id(self) + + +class ReasoningBlock: + """Model's thought process. Either plain reasoning (with an optional +signature) or an opaque redacted blob.""" + def __init__( + self, + *, + text: Optional[str], + signature: Optional[str], + redacted_content: Optional[bytes], + ) -> None: + setattr(self, 'text', text) + setattr(self, 'signature', signature) + setattr(self, 'redacted-content', redacted_content) + + @property + def redacted_content(self) -> Optional[bytes]: + return getattr(self, 'redacted-content') + + def __repr__(self) -> str: + return f'ReasoningBlock(text={getattr(self, 'text')!r}, signature={getattr(self, 'signature')!r}, redacted_content={getattr(self, 'redacted-content')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ReasoningBlock): + return NotImplemented + return getattr(self, 'text') == getattr(other, 'text') and getattr(self, 'signature') == getattr(other, 'signature') and getattr(self, 'redacted-content') == getattr(other, 'redacted-content') + + def __hash__(self) -> int: + return id(self) + + +class CacheKind(str): + """Prompt-caching kind. More arms will be added as providers surface +additional cache tiers (e.g. Anthropic's `ephemeral`).""" + __slots__ = () + + DEFAULT_CACHE: 'CacheKind' + +CacheKind.DEFAULT_CACHE = CacheKind('default-cache') # type: ignore[attr-defined] + + +class CachePointBlock: + """Marks a caching boundary in the prompt.""" + def __init__( + self, + *, + kind: CacheKind, + ) -> None: + setattr(self, 'kind', kind) + + def __repr__(self) -> str: + return f'CachePointBlock(kind={getattr(self, 'kind')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CachePointBlock): + return NotImplemented + return getattr(self, 'kind') == getattr(other, 'kind') + + def __hash__(self) -> int: + return id(self) + + +class GuardQualifier(str): + """How a piece of guard content should be evaluated.""" + __slots__ = () + + GROUNDING_SOURCE: 'GuardQualifier' + QUERY: 'GuardQualifier' + GUARD_CONTENT: 'GuardQualifier' + +GuardQualifier.GROUNDING_SOURCE = GuardQualifier('grounding-source') # type: ignore[attr-defined] +GuardQualifier.QUERY = GuardQualifier('query') # type: ignore[attr-defined] +GuardQualifier.GUARD_CONTENT = GuardQualifier('guard-content') # type: ignore[attr-defined] + + +class GuardContentText: + """Text submitted to a guardrail for evaluation.""" + def __init__( + self, + *, + qualifiers: list[GuardQualifier], + text: str, + ) -> None: + setattr(self, 'qualifiers', qualifiers) + setattr(self, 'text', text) + + def __repr__(self) -> str: + return f'GuardContentText(qualifiers={getattr(self, 'qualifiers')!r}, text={getattr(self, 'text')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, GuardContentText): + return NotImplemented + return getattr(self, 'qualifiers') == getattr(other, 'qualifiers') and getattr(self, 'text') == getattr(other, 'text') + + def __hash__(self) -> int: + return id(self) + + +class GuardContentImage: + """Image submitted to a guardrail for evaluation.""" + def __init__( + self, + *, + format: str, + bytes: bytes, + ) -> None: + setattr(self, 'format', format) + setattr(self, 'bytes', bytes) + + def __repr__(self) -> str: + return f'GuardContentImage(format={getattr(self, 'format')!r}, bytes={getattr(self, 'bytes')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, GuardContentImage): + return NotImplemented + return getattr(self, 'format') == getattr(other, 'format') and getattr(self, 'bytes') == getattr(other, 'bytes') + + def __hash__(self) -> int: + return id(self) + + +class GuardContentBlock: + """Content submitted to a guardrail for evaluation.""" + pass + +class GuardContentBlock_Text(GuardContentBlock, _WitVariantCase): + """Text guard content.""" + tag = 'text' + +class GuardContentBlock_Image(GuardContentBlock, _WitVariantCase): + """Image guard content.""" + tag = 'image' + +_GuardContentBlock_CASES: dict[str, type] = { + 'text': GuardContentBlock_Text, + 'image': GuardContentBlock_Image, +} + +def _GuardContentBlock_lift(raw: _WitVariant) -> GuardContentBlock: + cls = _GuardContentBlock_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown GuardContentBlock arm: {raw.tag!r}') + return cls(raw.payload) +GuardContentBlock.lift = staticmethod(_GuardContentBlock_lift) # type: ignore[attr-defined] + +class DocumentRange: + """Range within a source document (characters, pages, or chunks).""" + def __init__( + self, + *, + document_index: int, + start: int, + end: int, + ) -> None: + setattr(self, 'document-index', document_index) + setattr(self, 'start', start) + setattr(self, 'end', end) + + @property + def document_index(self) -> int: + return getattr(self, 'document-index') + + def __repr__(self) -> str: + return f'DocumentRange(document_index={getattr(self, 'document-index')!r}, start={getattr(self, 'start')!r}, end={getattr(self, 'end')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DocumentRange): + return NotImplemented + return getattr(self, 'document-index') == getattr(other, 'document-index') and getattr(self, 'start') == getattr(other, 'start') and getattr(self, 'end') == getattr(other, 'end') + + def __hash__(self) -> int: + return id(self) + + +class SearchResultRange: + """Range within a search result.""" + def __init__( + self, + *, + search_result_index: int, + start: int, + end: int, + ) -> None: + setattr(self, 'search-result-index', search_result_index) + setattr(self, 'start', start) + setattr(self, 'end', end) + + @property + def search_result_index(self) -> int: + return getattr(self, 'search-result-index') + + def __repr__(self) -> str: + return f'SearchResultRange(search_result_index={getattr(self, 'search-result-index')!r}, start={getattr(self, 'start')!r}, end={getattr(self, 'end')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SearchResultRange): + return NotImplemented + return getattr(self, 'search-result-index') == getattr(other, 'search-result-index') and getattr(self, 'start') == getattr(other, 'start') and getattr(self, 'end') == getattr(other, 'end') + + def __hash__(self) -> int: + return id(self) + + +class WebLocation: + """Web citation target.""" + def __init__( + self, + *, + url: str, + domain: Optional[str], + ) -> None: + setattr(self, 'url', url) + setattr(self, 'domain', domain) + + def __repr__(self) -> str: + return f'WebLocation(url={getattr(self, 'url')!r}, domain={getattr(self, 'domain')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, WebLocation): + return NotImplemented + return getattr(self, 'url') == getattr(other, 'url') and getattr(self, 'domain') == getattr(other, 'domain') + + def __hash__(self) -> int: + return id(self) + + +class CitationLocation: + """Anchor a citation points to.""" + pass + +class CitationLocation_DocumentChar(CitationLocation, _WitVariantCase): + """Character range within a document.""" + tag = 'document-char' + +class CitationLocation_DocumentPage(CitationLocation, _WitVariantCase): + """Page range within a document.""" + tag = 'document-page' + +class CitationLocation_DocumentChunk(CitationLocation, _WitVariantCase): + """Chunk range within a document.""" + tag = 'document-chunk' + +class CitationLocation_SearchResult(CitationLocation, _WitVariantCase): + """Range within a search result.""" + tag = 'search-result' + +class CitationLocation_Web(CitationLocation, _WitVariantCase): + """Web page.""" + tag = 'web' + +_CitationLocation_CASES: dict[str, type] = { + 'document-char': CitationLocation_DocumentChar, + 'document-page': CitationLocation_DocumentPage, + 'document-chunk': CitationLocation_DocumentChunk, + 'search-result': CitationLocation_SearchResult, + 'web': CitationLocation_Web, +} + +def _CitationLocation_lift(raw: _WitVariant) -> CitationLocation: + cls = _CitationLocation_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown CitationLocation arm: {raw.tag!r}') + return cls(raw.payload) +CitationLocation.lift = staticmethod(_CitationLocation_lift) # type: ignore[attr-defined] + +class CitationText: + """Text fragment from a source or a generated answer.""" + def __init__( + self, + *, + text: str, + ) -> None: + setattr(self, 'text', text) + + def __repr__(self) -> str: + return f'CitationText(text={getattr(self, 'text')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CitationText): + return NotImplemented + return getattr(self, 'text') == getattr(other, 'text') + + def __hash__(self) -> int: + return id(self) + + +class Citation: + """Link from generated content back to a source location.""" + def __init__( + self, + *, + location: CitationLocation, + source: str, + source_content: list[CitationText], + title: str, + ) -> None: + setattr(self, 'location', location) + setattr(self, 'source', source) + setattr(self, 'source-content', source_content) + setattr(self, 'title', title) + + @property + def source_content(self) -> list[CitationText]: + return getattr(self, 'source-content') + + def __repr__(self) -> str: + return f'Citation(location={getattr(self, 'location')!r}, source={getattr(self, 'source')!r}, source_content={getattr(self, 'source-content')!r}, title={getattr(self, 'title')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Citation): + return NotImplemented + return getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'source') == getattr(other, 'source') and getattr(self, 'source-content') == getattr(other, 'source-content') and getattr(self, 'title') == getattr(other, 'title') + + def __hash__(self) -> int: + return id(self) + + +class CitationsBlock: + """Citations emitted by the model when citations are enabled.""" + def __init__( + self, + *, + citations: list[Citation], + content: list[CitationText], + ) -> None: + setattr(self, 'citations', citations) + setattr(self, 'content', content) + + def __repr__(self) -> str: + return f'CitationsBlock(citations={getattr(self, 'citations')!r}, content={getattr(self, 'content')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CitationsBlock): + return NotImplemented + return getattr(self, 'citations') == getattr(other, 'citations') and getattr(self, 'content') == getattr(other, 'content') + + def __hash__(self) -> int: + return id(self) + + +class ToolUseBlock: + """Model's request to call a tool.""" + def __init__( + self, + *, + name: str, + tool_use_id: str, + input: str, + reasoning_signature: Optional[str], + ) -> None: + setattr(self, 'name', name) + setattr(self, 'tool-use-id', tool_use_id) + setattr(self, 'input', input) + setattr(self, 'reasoning-signature', reasoning_signature) + + @property + def tool_use_id(self) -> str: + return getattr(self, 'tool-use-id') + + @property + def reasoning_signature(self) -> Optional[str]: + return getattr(self, 'reasoning-signature') + + def __repr__(self) -> str: + return f'ToolUseBlock(name={getattr(self, 'name')!r}, tool_use_id={getattr(self, 'tool-use-id')!r}, input={getattr(self, 'input')!r}, reasoning_signature={getattr(self, 'reasoning-signature')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ToolUseBlock): + return NotImplemented + return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'tool-use-id') == getattr(other, 'tool-use-id') and getattr(self, 'input') == getattr(other, 'input') and getattr(self, 'reasoning-signature') == getattr(other, 'reasoning-signature') + + def __hash__(self) -> int: + return id(self) + + +class ToolResultStatus(str): + """Whether a tool invocation succeeded. Richer classification lives on `tools.tool-error`.""" + __slots__ = () + + SUCCESS: 'ToolResultStatus' + ERROR: 'ToolResultStatus' + +ToolResultStatus.SUCCESS = ToolResultStatus('success') # type: ignore[attr-defined] +ToolResultStatus.ERROR = ToolResultStatus('error') # type: ignore[attr-defined] + + +class JsonBlock: + """Structured JSON payload. Used for tool results and agent-as-tool +outputs that carry schema-validated data, not prose.""" + def __init__( + self, + *, + json: str, + ) -> None: + setattr(self, 'json', json) + + def __repr__(self) -> str: + return f'JsonBlock(json={getattr(self, 'json')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, JsonBlock): + return NotImplemented + return getattr(self, 'json') == getattr(other, 'json') + + def __hash__(self) -> int: + return id(self) + + +class ToolResultContent: + """Block valid inside `tool-result-block.content`. Narrower than `content-block`.""" + pass + +class ToolResultContent_Text(ToolResultContent, _WitVariantCase): + """Text output.""" + tag = 'text' + +class ToolResultContent_Json(ToolResultContent, _WitVariantCase): + """Structured JSON output.""" + tag = 'json' + +class ToolResultContent_Image(ToolResultContent, _WitVariantCase): + """Image output.""" + tag = 'image' + +class ToolResultContent_Video(ToolResultContent, _WitVariantCase): + """Video output.""" + tag = 'video' + +class ToolResultContent_Document(ToolResultContent, _WitVariantCase): + """Document output.""" + tag = 'document' + +_ToolResultContent_CASES: dict[str, type] = { + 'text': ToolResultContent_Text, + 'json': ToolResultContent_Json, + 'image': ToolResultContent_Image, + 'video': ToolResultContent_Video, + 'document': ToolResultContent_Document, +} + +def _ToolResultContent_lift(raw: _WitVariant) -> ToolResultContent: + cls = _ToolResultContent_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ToolResultContent arm: {raw.tag!r}') + return cls(raw.payload) +ToolResultContent.lift = staticmethod(_ToolResultContent_lift) # type: ignore[attr-defined] + +class ToolResultBlock: + """Outcome of a tool execution.""" + def __init__( + self, + *, + tool_use_id: str, + status: ToolResultStatus, + content: list[ToolResultContent], + ) -> None: + setattr(self, 'tool-use-id', tool_use_id) + setattr(self, 'status', status) + setattr(self, 'content', content) + + @property + def tool_use_id(self) -> str: + return getattr(self, 'tool-use-id') + + def __repr__(self) -> str: + return f'ToolResultBlock(tool_use_id={getattr(self, 'tool-use-id')!r}, status={getattr(self, 'status')!r}, content={getattr(self, 'content')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ToolResultBlock): + return NotImplemented + return getattr(self, 'tool-use-id') == getattr(other, 'tool-use-id') and getattr(self, 'status') == getattr(other, 'status') and getattr(self, 'content') == getattr(other, 'content') + + def __hash__(self) -> int: + return id(self) + + +class InterruptResponseBlock: + """User response to a previously-raised interrupt. Supplied on the +next invocation to resume the paused agent.""" + def __init__( + self, + *, + interrupt_id: str, + response: str, + ) -> None: + setattr(self, 'interrupt-id', interrupt_id) + setattr(self, 'response', response) + + @property + def interrupt_id(self) -> str: + return getattr(self, 'interrupt-id') + + def __repr__(self) -> str: + return f'InterruptResponseBlock(interrupt_id={getattr(self, 'interrupt-id')!r}, response={getattr(self, 'response')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, InterruptResponseBlock): + return NotImplemented + return getattr(self, 'interrupt-id') == getattr(other, 'interrupt-id') and getattr(self, 'response') == getattr(other, 'response') + + def __hash__(self) -> int: + return id(self) + + +class ContentBlock: + """Any block that can appear inside a message.""" + pass + +class ContentBlock_Text(ContentBlock, _WitVariantCase): + """Plain text.""" + tag = 'text' + +class ContentBlock_Json(ContentBlock, _WitVariantCase): + """Structured JSON payload.""" + tag = 'json' + +class ContentBlock_ToolUse(ContentBlock, _WitVariantCase): + """Model requested a tool call.""" + tag = 'tool-use' + +class ContentBlock_ToolResult(ContentBlock, _WitVariantCase): + """Tool call completed.""" + tag = 'tool-result' + +class ContentBlock_Reasoning(ContentBlock, _WitVariantCase): + """Model reasoning.""" + tag = 'reasoning' + +class ContentBlock_CachePoint(ContentBlock, _WitVariantCase): + """Caching boundary marker.""" + tag = 'cache-point' + +class ContentBlock_GuardContent(ContentBlock, _WitVariantCase): + """Content submitted for guardrail evaluation.""" + tag = 'guard-content' + +class ContentBlock_Image(ContentBlock, _WitVariantCase): + """Image.""" + tag = 'image' + +class ContentBlock_Video(ContentBlock, _WitVariantCase): + """Video.""" + tag = 'video' + +class ContentBlock_Document(ContentBlock, _WitVariantCase): + """Document.""" + tag = 'document' + +class ContentBlock_Citations(ContentBlock, _WitVariantCase): + """Citations emitted by the model.""" + tag = 'citations' + +class ContentBlock_InterruptResponse(ContentBlock, _WitVariantCase): + """Response to a prior interrupt, supplied when resuming.""" + tag = 'interrupt-response' + +_ContentBlock_CASES: dict[str, type] = { + 'text': ContentBlock_Text, + 'json': ContentBlock_Json, + 'tool-use': ContentBlock_ToolUse, + 'tool-result': ContentBlock_ToolResult, + 'reasoning': ContentBlock_Reasoning, + 'cache-point': ContentBlock_CachePoint, + 'guard-content': ContentBlock_GuardContent, + 'image': ContentBlock_Image, + 'video': ContentBlock_Video, + 'document': ContentBlock_Document, + 'citations': ContentBlock_Citations, + 'interrupt-response': ContentBlock_InterruptResponse, +} + +def _ContentBlock_lift(raw: _WitVariant) -> ContentBlock: + cls = _ContentBlock_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ContentBlock arm: {raw.tag!r}') + return cls(raw.payload) +ContentBlock.lift = staticmethod(_ContentBlock_lift) # type: ignore[attr-defined] + +class Role(str): + """Who a message is from.""" + __slots__ = () + + USER: 'Role' + ASSISTANT: 'Role' + +Role.USER = Role('user') # type: ignore[attr-defined] +Role.ASSISTANT = Role('assistant') # type: ignore[attr-defined] + + +class Usage: + """Token consumption for a model invocation.""" + def __init__( + self, + *, + input_tokens: int, + output_tokens: int, + total_tokens: int, + cache_read_input_tokens: Optional[int], + cache_write_input_tokens: Optional[int], + ) -> None: + setattr(self, 'input-tokens', input_tokens) + setattr(self, 'output-tokens', output_tokens) + setattr(self, 'total-tokens', total_tokens) + setattr(self, 'cache-read-input-tokens', cache_read_input_tokens) + setattr(self, 'cache-write-input-tokens', cache_write_input_tokens) + + @property + def input_tokens(self) -> int: + return getattr(self, 'input-tokens') + + @property + def output_tokens(self) -> int: + return getattr(self, 'output-tokens') + + @property + def total_tokens(self) -> int: + return getattr(self, 'total-tokens') + + @property + def cache_read_input_tokens(self) -> Optional[int]: + return getattr(self, 'cache-read-input-tokens') + + @property + def cache_write_input_tokens(self) -> Optional[int]: + return getattr(self, 'cache-write-input-tokens') + + def __repr__(self) -> str: + return f'Usage(input_tokens={getattr(self, 'input-tokens')!r}, output_tokens={getattr(self, 'output-tokens')!r}, total_tokens={getattr(self, 'total-tokens')!r}, cache_read_input_tokens={getattr(self, 'cache-read-input-tokens')!r}, cache_write_input_tokens={getattr(self, 'cache-write-input-tokens')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Usage): + return NotImplemented + return getattr(self, 'input-tokens') == getattr(other, 'input-tokens') and getattr(self, 'output-tokens') == getattr(other, 'output-tokens') and getattr(self, 'total-tokens') == getattr(other, 'total-tokens') and getattr(self, 'cache-read-input-tokens') == getattr(other, 'cache-read-input-tokens') and getattr(self, 'cache-write-input-tokens') == getattr(other, 'cache-write-input-tokens') + + def __hash__(self) -> int: + return id(self) + + +class Metrics: + """Performance metrics for a model invocation.""" + def __init__( + self, + *, + latency_ms: float, + ) -> None: + setattr(self, 'latency-ms', latency_ms) + + @property + def latency_ms(self) -> float: + return getattr(self, 'latency-ms') + + def __repr__(self) -> str: + return f'Metrics(latency_ms={getattr(self, 'latency-ms')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Metrics): + return NotImplemented + return getattr(self, 'latency-ms') == getattr(other, 'latency-ms') + + def __hash__(self) -> int: + return id(self) + + +class MessageMetadata: + """Metadata attached to a message. Not sent to model providers; persisted +alongside the message for bookkeeping.""" + def __init__( + self, + *, + usage: Optional[Usage], + metrics: Optional[Metrics], + custom: Optional[str], + ) -> None: + setattr(self, 'usage', usage) + setattr(self, 'metrics', metrics) + setattr(self, 'custom', custom) + + def __repr__(self) -> str: + return f'MessageMetadata(usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r}, custom={getattr(self, 'custom')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MessageMetadata): + return NotImplemented + return getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') and getattr(self, 'custom') == getattr(other, 'custom') + + def __hash__(self) -> int: + return id(self) + + +class Message: + """A complete message in a conversation.""" + def __init__( + self, + *, + role: Role, + content: list[ContentBlock], + metadata: Optional[MessageMetadata], + ) -> None: + setattr(self, 'role', role) + setattr(self, 'content', content) + setattr(self, 'metadata', metadata) + + def __repr__(self) -> str: + return f'Message(role={getattr(self, 'role')!r}, content={getattr(self, 'content')!r}, metadata={getattr(self, 'metadata')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Message): + return NotImplemented + return getattr(self, 'role') == getattr(other, 'role') and getattr(self, 'content') == getattr(other, 'content') and getattr(self, 'metadata') == getattr(other, 'metadata') + + def __hash__(self) -> int: + return id(self) + + +PromptInput = str | list[ContentBlock] +"""A prompt-style input: either prose or structured content. Used for +both system prompts and user input.""" + +class AnthropicConfig: + """Anthropic API model configuration.""" + def __init__( + self, + *, + model_id: Optional[str], + api_key: Optional[str], + additional_config: Optional[str], + ) -> None: + setattr(self, 'model-id', model_id) + setattr(self, 'api-key', api_key) + setattr(self, 'additional-config', additional_config) + + @property + def model_id(self) -> Optional[str]: + return getattr(self, 'model-id') + + @property + def api_key(self) -> Optional[str]: + return getattr(self, 'api-key') + + @property + def additional_config(self) -> Optional[str]: + return getattr(self, 'additional-config') + + def __repr__(self) -> str: + return f'AnthropicConfig(model_id={getattr(self, 'model-id')!r}, api_key={getattr(self, 'api-key')!r}, additional_config={getattr(self, 'additional-config')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AnthropicConfig): + return NotImplemented + return getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'api-key') == getattr(other, 'api-key') and getattr(self, 'additional-config') == getattr(other, 'additional-config') + + def __hash__(self) -> int: + return id(self) + + +class BedrockConfig: + """AWS Bedrock model configuration.""" + def __init__( + self, + *, + model_id: str, + region: Optional[str], + access_key_id: Optional[str], + secret_access_key: Optional[str], + session_token: Optional[str], + additional_config: Optional[str], + ) -> None: + setattr(self, 'model-id', model_id) + setattr(self, 'region', region) + setattr(self, 'access-key-id', access_key_id) + setattr(self, 'secret-access-key', secret_access_key) + setattr(self, 'session-token', session_token) + setattr(self, 'additional-config', additional_config) + + @property + def model_id(self) -> str: + return getattr(self, 'model-id') + + @property + def access_key_id(self) -> Optional[str]: + return getattr(self, 'access-key-id') + + @property + def secret_access_key(self) -> Optional[str]: + return getattr(self, 'secret-access-key') + + @property + def session_token(self) -> Optional[str]: + return getattr(self, 'session-token') + + @property + def additional_config(self) -> Optional[str]: + return getattr(self, 'additional-config') + + def __repr__(self) -> str: + return f'BedrockConfig(model_id={getattr(self, 'model-id')!r}, region={getattr(self, 'region')!r}, access_key_id={getattr(self, 'access-key-id')!r}, secret_access_key={getattr(self, 'secret-access-key')!r}, session_token={getattr(self, 'session-token')!r}, additional_config={getattr(self, 'additional-config')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, BedrockConfig): + return NotImplemented + return getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'region') == getattr(other, 'region') and getattr(self, 'access-key-id') == getattr(other, 'access-key-id') and getattr(self, 'secret-access-key') == getattr(other, 'secret-access-key') and getattr(self, 'session-token') == getattr(other, 'session-token') and getattr(self, 'additional-config') == getattr(other, 'additional-config') + + def __hash__(self) -> int: + return id(self) + + +class OpenaiConfig: + """OpenAI API model configuration.""" + def __init__( + self, + *, + model_id: Optional[str], + api_key: Optional[str], + additional_config: Optional[str], + ) -> None: + setattr(self, 'model-id', model_id) + setattr(self, 'api-key', api_key) + setattr(self, 'additional-config', additional_config) + + @property + def model_id(self) -> Optional[str]: + return getattr(self, 'model-id') + + @property + def api_key(self) -> Optional[str]: + return getattr(self, 'api-key') + + @property + def additional_config(self) -> Optional[str]: + return getattr(self, 'additional-config') + + def __repr__(self) -> str: + return f'OpenaiConfig(model_id={getattr(self, 'model-id')!r}, api_key={getattr(self, 'api-key')!r}, additional_config={getattr(self, 'additional-config')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OpenaiConfig): + return NotImplemented + return getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'api-key') == getattr(other, 'api-key') and getattr(self, 'additional-config') == getattr(other, 'additional-config') + + def __hash__(self) -> int: + return id(self) + + +class GeminiConfig: + """Google Gemini API model configuration.""" + def __init__( + self, + *, + model_id: Optional[str], + api_key: Optional[str], + additional_config: Optional[str], + ) -> None: + setattr(self, 'model-id', model_id) + setattr(self, 'api-key', api_key) + setattr(self, 'additional-config', additional_config) + + @property + def model_id(self) -> Optional[str]: + return getattr(self, 'model-id') + + @property + def api_key(self) -> Optional[str]: + return getattr(self, 'api-key') + + @property + def additional_config(self) -> Optional[str]: + return getattr(self, 'additional-config') + + def __repr__(self) -> str: + return f'GeminiConfig(model_id={getattr(self, 'model-id')!r}, api_key={getattr(self, 'api-key')!r}, additional_config={getattr(self, 'additional-config')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, GeminiConfig): + return NotImplemented + return getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'api-key') == getattr(other, 'api-key') and getattr(self, 'additional-config') == getattr(other, 'additional-config') + + def __hash__(self) -> int: + return id(self) + + +class CustomModelConfig: + """Custom model provider supplied by your application.""" + def __init__( + self, + *, + provider_id: str, + model_id: Optional[str], + additional_config: Optional[str], + stateful: bool, + ) -> None: + setattr(self, 'provider-id', provider_id) + setattr(self, 'model-id', model_id) + setattr(self, 'additional-config', additional_config) + setattr(self, 'stateful', stateful) + + @property + def provider_id(self) -> str: + return getattr(self, 'provider-id') + + @property + def model_id(self) -> Optional[str]: + return getattr(self, 'model-id') + + @property + def additional_config(self) -> Optional[str]: + return getattr(self, 'additional-config') + + def __repr__(self) -> str: + return f'CustomModelConfig(provider_id={getattr(self, 'provider-id')!r}, model_id={getattr(self, 'model-id')!r}, additional_config={getattr(self, 'additional-config')!r}, stateful={getattr(self, 'stateful')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CustomModelConfig): + return NotImplemented + return getattr(self, 'provider-id') == getattr(other, 'provider-id') and getattr(self, 'model-id') == getattr(other, 'model-id') and getattr(self, 'additional-config') == getattr(other, 'additional-config') and getattr(self, 'stateful') == getattr(other, 'stateful') + + def __hash__(self) -> int: + return id(self) + + +class ModelConfig: + """Which model provider the agent should use.""" + pass + +class ModelConfig_Anthropic(ModelConfig, _WitVariantCase): + """Anthropic API.""" + tag = 'anthropic' + +class ModelConfig_Bedrock(ModelConfig, _WitVariantCase): + """AWS Bedrock.""" + tag = 'bedrock' + +class ModelConfig_Openai(ModelConfig, _WitVariantCase): + """OpenAI API.""" + tag = 'openai' + +class ModelConfig_Gemini(ModelConfig, _WitVariantCase): + """Google Gemini API.""" + tag = 'gemini' + +class ModelConfig_Custom(ModelConfig, _WitVariantCase): + """Custom provider supplied by your application. Implement the +`model-provider` interface to serve it.""" + tag = 'custom' + +_ModelConfig_CASES: dict[str, type] = { + 'anthropic': ModelConfig_Anthropic, + 'bedrock': ModelConfig_Bedrock, + 'openai': ModelConfig_Openai, + 'gemini': ModelConfig_Gemini, + 'custom': ModelConfig_Custom, +} + +def _ModelConfig_lift(raw: _WitVariant) -> ModelConfig: + cls = _ModelConfig_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ModelConfig arm: {raw.tag!r}') + return cls(raw.payload) +ModelConfig.lift = staticmethod(_ModelConfig_lift) # type: ignore[attr-defined] + +class ModelParams: + """Sampling parameters applied to every call on the chosen provider.""" + def __init__( + self, + *, + max_tokens: Optional[int], + temperature: Optional[float], + top_p: Optional[float], + ) -> None: + setattr(self, 'max-tokens', max_tokens) + setattr(self, 'temperature', temperature) + setattr(self, 'top-p', top_p) + + @property + def max_tokens(self) -> Optional[int]: + return getattr(self, 'max-tokens') + + @property + def top_p(self) -> Optional[float]: + return getattr(self, 'top-p') + + def __repr__(self) -> str: + return f'ModelParams(max_tokens={getattr(self, 'max-tokens')!r}, temperature={getattr(self, 'temperature')!r}, top_p={getattr(self, 'top-p')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ModelParams): + return NotImplemented + return getattr(self, 'max-tokens') == getattr(other, 'max-tokens') and getattr(self, 'temperature') == getattr(other, 'temperature') and getattr(self, 'top-p') == getattr(other, 'top-p') + + def __hash__(self) -> int: + return id(self) + + +class ModelError: + """Why a model call failed. Retry logic keys off of which arm fires, so +implementations should pick the narrowest one that fits.""" + pass + +class ModelError_UnknownProvider(ModelError, _WitVariantCase): + """No provider registered for the given `provider-id`.""" + tag = 'unknown-provider' + +class ModelError_InvalidRequest(ModelError, _WitVariantCase): + """Provider refused the request due to malformed input.""" + tag = 'invalid-request' + +class ModelError_Unauthorized(ModelError, _WitVariantCase): + """Caller lacks permission (missing or expired credentials).""" + tag = 'unauthorized' + +class ModelError_Throttled(ModelError, _WitVariantCase): + """Provider returned a rate-limit error. Retry after a backoff.""" + tag = 'throttled' + +class ModelError_ServerError(ModelError, _WitVariantCase): + """Provider returned a server-side error. Retry may succeed.""" + tag = 'server-error' + +class ModelError_ContextWindowExceeded(ModelError, _WitVariantCase): + """Request exceeded the model's context window.""" + tag = 'context-window-exceeded' + +class ModelError_ContentFiltered(ModelError, _WitVariantCase): + """Content was rejected by provider safety policy.""" + tag = 'content-filtered' + +class ModelError_Transient(ModelError, _WitVariantCase): + """Transient network or transport failure. Retry may succeed.""" + tag = 'transient' + +class ModelError_Internal(ModelError, _WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + +_ModelError_CASES: dict[str, type] = { + 'unknown-provider': ModelError_UnknownProvider, + 'invalid-request': ModelError_InvalidRequest, + 'unauthorized': ModelError_Unauthorized, + 'throttled': ModelError_Throttled, + 'server-error': ModelError_ServerError, + 'context-window-exceeded': ModelError_ContextWindowExceeded, + 'content-filtered': ModelError_ContentFiltered, + 'transient': ModelError_Transient, + 'internal': ModelError_Internal, +} + +def _ModelError_lift(raw: _WitVariant) -> ModelError: + cls = _ModelError_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ModelError arm: {raw.tag!r}') + return cls(raw.payload) +ModelError.lift = staticmethod(_ModelError_lift) # type: ignore[attr-defined] + +class SlidingWindowConfig: + """Sliding-window strategy: trim oldest messages once the conversation +exceeds `window-size`.""" + def __init__( + self, + *, + window_size: int, + should_truncate_results: bool, + ) -> None: + setattr(self, 'window-size', window_size) + setattr(self, 'should-truncate-results', should_truncate_results) + + @property + def window_size(self) -> int: + return getattr(self, 'window-size') + + @property + def should_truncate_results(self) -> bool: + return getattr(self, 'should-truncate-results') + + def __repr__(self) -> str: + return f'SlidingWindowConfig(window_size={getattr(self, 'window-size')!r}, should_truncate_results={getattr(self, 'should-truncate-results')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SlidingWindowConfig): + return NotImplemented + return getattr(self, 'window-size') == getattr(other, 'window-size') and getattr(self, 'should-truncate-results') == getattr(other, 'should-truncate-results') + + def __hash__(self) -> int: + return id(self) + + +class SummarizingConfig: + """Summarizing strategy: once the conversation grows, summarize older +messages into a single summary message and keep the rest verbatim.""" + def __init__( + self, + *, + summary_ratio: float, + preserve_recent_messages: int, + summarization_system_prompt: Optional[str], + summarization_model: Optional[ModelConfig], + ) -> None: + setattr(self, 'summary-ratio', summary_ratio) + setattr(self, 'preserve-recent-messages', preserve_recent_messages) + setattr(self, 'summarization-system-prompt', summarization_system_prompt) + setattr(self, 'summarization-model', summarization_model) + + @property + def summary_ratio(self) -> float: + return getattr(self, 'summary-ratio') + + @property + def preserve_recent_messages(self) -> int: + return getattr(self, 'preserve-recent-messages') + + @property + def summarization_system_prompt(self) -> Optional[str]: + return getattr(self, 'summarization-system-prompt') + + @property + def summarization_model(self) -> Optional[ModelConfig]: + return getattr(self, 'summarization-model') + + def __repr__(self) -> str: + return f'SummarizingConfig(summary_ratio={getattr(self, 'summary-ratio')!r}, preserve_recent_messages={getattr(self, 'preserve-recent-messages')!r}, summarization_system_prompt={getattr(self, 'summarization-system-prompt')!r}, summarization_model={getattr(self, 'summarization-model')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SummarizingConfig): + return NotImplemented + return getattr(self, 'summary-ratio') == getattr(other, 'summary-ratio') and getattr(self, 'preserve-recent-messages') == getattr(other, 'preserve-recent-messages') and getattr(self, 'summarization-system-prompt') == getattr(other, 'summarization-system-prompt') and getattr(self, 'summarization-model') == getattr(other, 'summarization-model') + + def __hash__(self) -> int: + return id(self) + + +class ConversationManagerConfig: + """Which conversation manager the agent uses.""" + pass + +class ConversationManagerConfig_None(ConversationManagerConfig, _WitVariantCase): + """No conversation management. History grows without bound and +context-overflow errors propagate to the caller.""" + tag = 'none' + +class ConversationManagerConfig_SlidingWindow(ConversationManagerConfig, _WitVariantCase): + """Sliding-window trimming.""" + tag = 'sliding-window' + +class ConversationManagerConfig_Summarizing(ConversationManagerConfig, _WitVariantCase): + """Summarization of older messages.""" + tag = 'summarizing' + +_ConversationManagerConfig_CASES: dict[str, type] = { + 'none': ConversationManagerConfig_None, + 'sliding-window': ConversationManagerConfig_SlidingWindow, + 'summarizing': ConversationManagerConfig_Summarizing, +} + +def _ConversationManagerConfig_lift(raw: _WitVariant) -> ConversationManagerConfig: + cls = _ConversationManagerConfig_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ConversationManagerConfig arm: {raw.tag!r}') + return cls(raw.payload) +ConversationManagerConfig.lift = staticmethod(_ConversationManagerConfig_lift) # type: ignore[attr-defined] + +class JitterKind(str): + """How much random variation to apply to computed delays.""" + __slots__ = () + + NONE: 'JitterKind' + FULL: 'JitterKind' + EQUAL: 'JitterKind' + DECORRELATED: 'JitterKind' + +JitterKind.NONE = JitterKind('none') # type: ignore[attr-defined] +JitterKind.FULL = JitterKind('full') # type: ignore[attr-defined] +JitterKind.EQUAL = JitterKind('equal') # type: ignore[attr-defined] +JitterKind.DECORRELATED = JitterKind('decorrelated') # type: ignore[attr-defined] + + +class ConstantBackoffConfig: + """Fixed delay between attempts.""" + def __init__( + self, + *, + delay: int, + ) -> None: + setattr(self, 'delay', delay) + + def __repr__(self) -> str: + return f'ConstantBackoffConfig(delay={getattr(self, 'delay')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ConstantBackoffConfig): + return NotImplemented + return getattr(self, 'delay') == getattr(other, 'delay') + + def __hash__(self) -> int: + return id(self) + + +class LinearBackoffConfig: + """Delay grows linearly with attempt number.""" + def __init__( + self, + *, + base: int, + max: int, + jitter: JitterKind, + ) -> None: + setattr(self, 'base', base) + setattr(self, 'max', max) + setattr(self, 'jitter', jitter) + + def __repr__(self) -> str: + return f'LinearBackoffConfig(base={getattr(self, 'base')!r}, max={getattr(self, 'max')!r}, jitter={getattr(self, 'jitter')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, LinearBackoffConfig): + return NotImplemented + return getattr(self, 'base') == getattr(other, 'base') and getattr(self, 'max') == getattr(other, 'max') and getattr(self, 'jitter') == getattr(other, 'jitter') + + def __hash__(self) -> int: + return id(self) + + +class ExponentialBackoffConfig: + """Delay grows exponentially with attempt number.""" + def __init__( + self, + *, + base: int, + max: int, + factor: float, + jitter: JitterKind, + ) -> None: + setattr(self, 'base', base) + setattr(self, 'max', max) + setattr(self, 'factor', factor) + setattr(self, 'jitter', jitter) + + def __repr__(self) -> str: + return f'ExponentialBackoffConfig(base={getattr(self, 'base')!r}, max={getattr(self, 'max')!r}, factor={getattr(self, 'factor')!r}, jitter={getattr(self, 'jitter')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ExponentialBackoffConfig): + return NotImplemented + return getattr(self, 'base') == getattr(other, 'base') and getattr(self, 'max') == getattr(other, 'max') and getattr(self, 'factor') == getattr(other, 'factor') and getattr(self, 'jitter') == getattr(other, 'jitter') + + def __hash__(self) -> int: + return id(self) + + +class BackoffStrategy: + """Backoff curve applied between attempts.""" + pass + +class BackoffStrategy_Constant(BackoffStrategy, _WitVariantCase): + """Fixed delay.""" + tag = 'constant' + +class BackoffStrategy_Linear(BackoffStrategy, _WitVariantCase): + """Linear growth.""" + tag = 'linear' + +class BackoffStrategy_Exponential(BackoffStrategy, _WitVariantCase): + """Exponential growth.""" + tag = 'exponential' + +_BackoffStrategy_CASES: dict[str, type] = { + 'constant': BackoffStrategy_Constant, + 'linear': BackoffStrategy_Linear, + 'exponential': BackoffStrategy_Exponential, +} + +def _BackoffStrategy_lift(raw: _WitVariant) -> BackoffStrategy: + cls = _BackoffStrategy_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown BackoffStrategy arm: {raw.tag!r}') + return cls(raw.payload) +BackoffStrategy.lift = staticmethod(_BackoffStrategy_lift) # type: ignore[attr-defined] + +class ModelRetryStrategy: + """A single retry strategy. Default is exponential backoff, full jitter, 6 attempts.""" + def __init__( + self, + *, + max_attempts: int, + backoff: BackoffStrategy, + total_budget: Optional[int], + ) -> None: + setattr(self, 'max-attempts', max_attempts) + setattr(self, 'backoff', backoff) + setattr(self, 'total-budget', total_budget) + + @property + def max_attempts(self) -> int: + return getattr(self, 'max-attempts') + + @property + def total_budget(self) -> Optional[int]: + return getattr(self, 'total-budget') + + def __repr__(self) -> str: + return f'ModelRetryStrategy(max_attempts={getattr(self, 'max-attempts')!r}, backoff={getattr(self, 'backoff')!r}, total_budget={getattr(self, 'total-budget')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ModelRetryStrategy): + return NotImplemented + return getattr(self, 'max-attempts') == getattr(other, 'max-attempts') and getattr(self, 'backoff') == getattr(other, 'backoff') and getattr(self, 'total-budget') == getattr(other, 'total-budget') + + def __hash__(self) -> int: + return id(self) + + +class RetryConfig: + """Retry configuration attached to an agent. +Every strategy observes every failure; the first to request a delay wins. +Empty list disables retries; omitting `agent-config.retry` applies a default +single exponential strategy. Two strategies with the same `backoff` arm +surface as `agent-error::invalid-input`.""" + def __init__( + self, + *, + strategies: list[ModelRetryStrategy], + ) -> None: + setattr(self, 'strategies', strategies) + + def __repr__(self) -> str: + return f'RetryConfig(strategies={getattr(self, 'strategies')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RetryConfig): + return NotImplemented + return getattr(self, 'strategies') == getattr(other, 'strategies') + + def __hash__(self) -> int: + return id(self) + + +class FileStorageConfig: + """Local filesystem snapshot storage.""" + def __init__( + self, + *, + base_dir: str, + ) -> None: + setattr(self, 'base-dir', base_dir) + + @property + def base_dir(self) -> str: + return getattr(self, 'base-dir') + + def __repr__(self) -> str: + return f'FileStorageConfig(base_dir={getattr(self, 'base-dir')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, FileStorageConfig): + return NotImplemented + return getattr(self, 'base-dir') == getattr(other, 'base-dir') + + def __hash__(self) -> int: + return id(self) + + +class S3StorageConfig: + """S3 snapshot storage.""" + def __init__( + self, + *, + bucket: str, + region: Optional[str], + prefix: Optional[str], + ) -> None: + setattr(self, 'bucket', bucket) + setattr(self, 'region', region) + setattr(self, 'prefix', prefix) + + def __repr__(self) -> str: + return f'S3StorageConfig(bucket={getattr(self, 'bucket')!r}, region={getattr(self, 'region')!r}, prefix={getattr(self, 'prefix')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, S3StorageConfig): + return NotImplemented + return getattr(self, 'bucket') == getattr(other, 'bucket') and getattr(self, 'region') == getattr(other, 'region') and getattr(self, 'prefix') == getattr(other, 'prefix') + + def __hash__(self) -> int: + return id(self) + + +class CustomStorageConfig: + """Reference to an application-implemented storage backend.""" + def __init__( + self, + *, + backend_id: str, + ) -> None: + setattr(self, 'backend-id', backend_id) + + @property + def backend_id(self) -> str: + return getattr(self, 'backend-id') + + def __repr__(self) -> str: + return f'CustomStorageConfig(backend_id={getattr(self, 'backend-id')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CustomStorageConfig): + return NotImplemented + return getattr(self, 'backend-id') == getattr(other, 'backend-id') + + def __hash__(self) -> int: + return id(self) + + +class StorageConfig: + """Where to persist session snapshots.""" + pass + +class StorageConfig_File(StorageConfig, _WitVariantCase): + """Local filesystem.""" + tag = 'file' + +class StorageConfig_S3(StorageConfig, _WitVariantCase): + """Amazon S3.""" + tag = 's3' + +class StorageConfig_Custom(StorageConfig, _WitVariantCase): + """Application-implemented backend.""" + tag = 'custom' + +_StorageConfig_CASES: dict[str, type] = { + 'file': StorageConfig_File, + 's3': StorageConfig_S3, + 'custom': StorageConfig_Custom, +} + +def _StorageConfig_lift(raw: _WitVariant) -> StorageConfig: + cls = _StorageConfig_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown StorageConfig arm: {raw.tag!r}') + return cls(raw.payload) +StorageConfig.lift = staticmethod(_StorageConfig_lift) # type: ignore[attr-defined] + +class SaveLatestPolicy: + """When to update the "latest" snapshot pointer. The `trigger` arm +carries the id of an application-supplied callback that decides +per-invocation.""" + pass + +class SaveLatestPolicy_Message(SaveLatestPolicy, _WitVariantCase): + """After every message added to the conversation.""" + tag = 'message' + +class SaveLatestPolicy_Invocation(SaveLatestPolicy, _WitVariantCase): + """Once per invocation, after it completes.""" + tag = 'invocation' + +class SaveLatestPolicy_Trigger(SaveLatestPolicy, _WitVariantCase): + """Each invocation consults the named `snapshot-trigger-handler`. +The id identifies which handler to invoke.""" + tag = 'trigger' + +_SaveLatestPolicy_CASES: dict[str, type] = { + 'message': SaveLatestPolicy_Message, + 'invocation': SaveLatestPolicy_Invocation, + 'trigger': SaveLatestPolicy_Trigger, +} + +def _SaveLatestPolicy_lift(raw: _WitVariant) -> SaveLatestPolicy: + cls = _SaveLatestPolicy_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown SaveLatestPolicy arm: {raw.tag!r}') + return cls(raw.payload) +SaveLatestPolicy.lift = staticmethod(_SaveLatestPolicy_lift) # type: ignore[attr-defined] + +class SessionConfig: + """Session persistence configuration attached to an agent.""" + def __init__( + self, + *, + session_id: str, + storage: StorageConfig, + save_latest: Optional[SaveLatestPolicy], + ) -> None: + setattr(self, 'session-id', session_id) + setattr(self, 'storage', storage) + setattr(self, 'save-latest', save_latest) + + @property + def session_id(self) -> str: + return getattr(self, 'session-id') + + @property + def save_latest(self) -> Optional[SaveLatestPolicy]: + return getattr(self, 'save-latest') + + def __repr__(self) -> str: + return f'SessionConfig(session_id={getattr(self, 'session-id')!r}, storage={getattr(self, 'storage')!r}, save_latest={getattr(self, 'save-latest')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SessionConfig): + return NotImplemented + return getattr(self, 'session-id') == getattr(other, 'session-id') and getattr(self, 'storage') == getattr(other, 'storage') and getattr(self, 'save-latest') == getattr(other, 'save-latest') + + def __hash__(self) -> int: + return id(self) + + +class SnapshotScope(str): + """Which kind of state a snapshot describes.""" + __slots__ = () + + AGENT: 'SnapshotScope' + MULTI_AGENT: 'SnapshotScope' + +SnapshotScope.AGENT = SnapshotScope('agent') # type: ignore[attr-defined] +SnapshotScope.MULTI_AGENT = SnapshotScope('multi-agent') # type: ignore[attr-defined] + + +class SnapshotLocation: + """Locator for a snapshot within the storage hierarchy.""" + def __init__( + self, + *, + session_id: str, + scope: SnapshotScope, + scope_id: str, + ) -> None: + setattr(self, 'session-id', session_id) + setattr(self, 'scope', scope) + setattr(self, 'scope-id', scope_id) + + @property + def session_id(self) -> str: + return getattr(self, 'session-id') + + @property + def scope_id(self) -> str: + return getattr(self, 'scope-id') + + def __repr__(self) -> str: + return f'SnapshotLocation(session_id={getattr(self, 'session-id')!r}, scope={getattr(self, 'scope')!r}, scope_id={getattr(self, 'scope-id')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SnapshotLocation): + return NotImplemented + return getattr(self, 'session-id') == getattr(other, 'session-id') and getattr(self, 'scope') == getattr(other, 'scope') and getattr(self, 'scope-id') == getattr(other, 'scope-id') + + def __hash__(self) -> int: + return id(self) + + +class SlidingWindowState: + """Sliding-window conversation manager state at snapshot time.""" + def __init__( + self, + *, + removed_message_count: int, + ) -> None: + setattr(self, 'removed-message-count', removed_message_count) + + @property + def removed_message_count(self) -> int: + return getattr(self, 'removed-message-count') + + def __repr__(self) -> str: + return f'SlidingWindowState(removed_message_count={getattr(self, 'removed-message-count')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SlidingWindowState): + return NotImplemented + return getattr(self, 'removed-message-count') == getattr(other, 'removed-message-count') + + def __hash__(self) -> int: + return id(self) + + +class SummarizingState: + """Summarizing conversation manager state at snapshot time.""" + def __init__( + self, + *, + summary_message: Optional[Message], + removed_message_count: int, + ) -> None: + setattr(self, 'summary-message', summary_message) + setattr(self, 'removed-message-count', removed_message_count) + + @property + def summary_message(self) -> Optional[Message]: + return getattr(self, 'summary-message') + + @property + def removed_message_count(self) -> int: + return getattr(self, 'removed-message-count') + + def __repr__(self) -> str: + return f'SummarizingState(summary_message={getattr(self, 'summary-message')!r}, removed_message_count={getattr(self, 'removed-message-count')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SummarizingState): + return NotImplemented + return getattr(self, 'summary-message') == getattr(other, 'summary-message') and getattr(self, 'removed-message-count') == getattr(other, 'removed-message-count') + + def __hash__(self) -> int: + return id(self) + + +class ConversationManagerState: + """Conversation manager snapshot state. Which arm is populated depends +on the conversation manager the agent was built with.""" + pass + +class ConversationManagerState_None(ConversationManagerState, _WitVariantCase): + """No conversation manager or null manager; nothing to persist.""" + tag = 'none' + +class ConversationManagerState_SlidingWindow(ConversationManagerState, _WitVariantCase): + """Sliding-window manager state.""" + tag = 'sliding-window' + +class ConversationManagerState_Summarizing(ConversationManagerState, _WitVariantCase): + """Summarizing manager state.""" + tag = 'summarizing' + +_ConversationManagerState_CASES: dict[str, type] = { + 'none': ConversationManagerState_None, + 'sliding-window': ConversationManagerState_SlidingWindow, + 'summarizing': ConversationManagerState_Summarizing, +} + +def _ConversationManagerState_lift(raw: _WitVariant) -> ConversationManagerState: + cls = _ConversationManagerState_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ConversationManagerState arm: {raw.tag!r}') + return cls(raw.payload) +ConversationManagerState.lift = staticmethod(_ConversationManagerState_lift) # type: ignore[attr-defined] + +class RetryStrategyState: + """Retry-strategy state at snapshot time.""" + def __init__( + self, + *, + attempts_used: int, + elapsed_ms: int, + ) -> None: + setattr(self, 'attempts-used', attempts_used) + setattr(self, 'elapsed-ms', elapsed_ms) + + @property + def attempts_used(self) -> int: + return getattr(self, 'attempts-used') + + @property + def elapsed_ms(self) -> int: + return getattr(self, 'elapsed-ms') + + def __repr__(self) -> str: + return f'RetryStrategyState(attempts_used={getattr(self, 'attempts-used')!r}, elapsed_ms={getattr(self, 'elapsed-ms')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RetryStrategyState): + return NotImplemented + return getattr(self, 'attempts-used') == getattr(other, 'attempts-used') and getattr(self, 'elapsed-ms') == getattr(other, 'elapsed-ms') + + def __hash__(self) -> int: + return id(self) + + +class PluginStateEntry: + """Named plugin state. `data` is an opaque JSON object owned by the plugin.""" + def __init__( + self, + *, + plugin_name: str, + data: str, + ) -> None: + setattr(self, 'plugin-name', plugin_name) + setattr(self, 'data', data) + + @property + def plugin_name(self) -> str: + return getattr(self, 'plugin-name') + + def __repr__(self) -> str: + return f'PluginStateEntry(plugin_name={getattr(self, 'plugin-name')!r}, data={getattr(self, 'data')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PluginStateEntry): + return NotImplemented + return getattr(self, 'plugin-name') == getattr(other, 'plugin-name') and getattr(self, 'data') == getattr(other, 'data') + + def __hash__(self) -> int: + return id(self) + + +class SnapshotData: + """Framework-owned snapshot state. All fields are optional because an +agent may not exercise every subsystem in a given run.""" + def __init__( + self, + *, + messages: list[Message], + conversation_manager: Optional[ConversationManagerState], + retry_strategy: Optional[RetryStrategyState], + model_state: Optional[str], + plugins: list[PluginStateEntry], + ) -> None: + setattr(self, 'messages', messages) + setattr(self, 'conversation-manager', conversation_manager) + setattr(self, 'retry-strategy', retry_strategy) + setattr(self, 'model-state', model_state) + setattr(self, 'plugins', plugins) + + @property + def conversation_manager(self) -> Optional[ConversationManagerState]: + return getattr(self, 'conversation-manager') + + @property + def retry_strategy(self) -> Optional[RetryStrategyState]: + return getattr(self, 'retry-strategy') + + @property + def model_state(self) -> Optional[str]: + return getattr(self, 'model-state') + + def __repr__(self) -> str: + return f'SnapshotData(messages={getattr(self, 'messages')!r}, conversation_manager={getattr(self, 'conversation-manager')!r}, retry_strategy={getattr(self, 'retry-strategy')!r}, model_state={getattr(self, 'model-state')!r}, plugins={getattr(self, 'plugins')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SnapshotData): + return NotImplemented + return getattr(self, 'messages') == getattr(other, 'messages') and getattr(self, 'conversation-manager') == getattr(other, 'conversation-manager') and getattr(self, 'retry-strategy') == getattr(other, 'retry-strategy') and getattr(self, 'model-state') == getattr(other, 'model-state') and getattr(self, 'plugins') == getattr(other, 'plugins') + + def __hash__(self) -> int: + return id(self) + + +class Snapshot: + """Point-in-time capture of agent or orchestrator state.""" + def __init__( + self, + *, + scope: SnapshotScope, + schema_version: str, + created_at: Datetime, + data: SnapshotData, + app_data: str, + ) -> None: + setattr(self, 'scope', scope) + setattr(self, 'schema-version', schema_version) + setattr(self, 'created-at', created_at) + setattr(self, 'data', data) + setattr(self, 'app-data', app_data) + + @property + def schema_version(self) -> str: + return getattr(self, 'schema-version') + + @property + def created_at(self) -> Datetime: + return getattr(self, 'created-at') + + @property + def app_data(self) -> str: + return getattr(self, 'app-data') + + def __repr__(self) -> str: + return f'Snapshot(scope={getattr(self, 'scope')!r}, schema_version={getattr(self, 'schema-version')!r}, created_at={getattr(self, 'created-at')!r}, data={getattr(self, 'data')!r}, app_data={getattr(self, 'app-data')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Snapshot): + return NotImplemented + return getattr(self, 'scope') == getattr(other, 'scope') and getattr(self, 'schema-version') == getattr(other, 'schema-version') and getattr(self, 'created-at') == getattr(other, 'created-at') and getattr(self, 'data') == getattr(other, 'data') and getattr(self, 'app-data') == getattr(other, 'app-data') + + def __hash__(self) -> int: + return id(self) + + +class SnapshotManifest: + """Metadata describing the snapshot manifest file.""" + def __init__( + self, + *, + schema_version: str, + updated_at: Datetime, + ) -> None: + setattr(self, 'schema-version', schema_version) + setattr(self, 'updated-at', updated_at) + + @property + def schema_version(self) -> str: + return getattr(self, 'schema-version') + + @property + def updated_at(self) -> Datetime: + return getattr(self, 'updated-at') + + def __repr__(self) -> str: + return f'SnapshotManifest(schema_version={getattr(self, 'schema-version')!r}, updated_at={getattr(self, 'updated-at')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SnapshotManifest): + return NotImplemented + return getattr(self, 'schema-version') == getattr(other, 'schema-version') and getattr(self, 'updated-at') == getattr(other, 'updated-at') + + def __hash__(self) -> int: + return id(self) + + +class StorageError: + """Why a snapshot operation failed.""" + pass + +class StorageError_NotFound(StorageError, _WitVariantCase): + """No snapshot or manifest at the requested location.""" + tag = 'not-found' + +class StorageError_AccessDenied(StorageError, _WitVariantCase): + """Caller lacks permission to read or write the storage.""" + tag = 'access-denied' + +class StorageError_OutOfSpace(StorageError, _WitVariantCase): + """Backing storage is full or over quota.""" + tag = 'out-of-space' + +class StorageError_Corrupt(StorageError, _WitVariantCase): + """Snapshot is malformed or cannot be deserialized.""" + tag = 'corrupt' + +class StorageError_Conflict(StorageError, _WitVariantCase): + """Concurrent writers collided; retrying may succeed.""" + tag = 'conflict' + +class StorageError_Transient(StorageError, _WitVariantCase): + """Transient I/O failure; retrying may succeed.""" + tag = 'transient' + +class StorageError_Permanent(StorageError, _WitVariantCase): + """Permanent backend failure.""" + tag = 'permanent' + +class StorageError_UnknownBackend(StorageError, _WitVariantCase): + """No custom backend registered for the given backend-id.""" + tag = 'unknown-backend' + +_StorageError_CASES: dict[str, type] = { + 'not-found': StorageError_NotFound, + 'access-denied': StorageError_AccessDenied, + 'out-of-space': StorageError_OutOfSpace, + 'corrupt': StorageError_Corrupt, + 'conflict': StorageError_Conflict, + 'transient': StorageError_Transient, + 'permanent': StorageError_Permanent, + 'unknown-backend': StorageError_UnknownBackend, +} + +def _StorageError_lift(raw: _WitVariant) -> StorageError: + cls = _StorageError_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown StorageError arm: {raw.tag!r}') + return cls(raw.payload) +StorageError.lift = staticmethod(_StorageError_lift) # type: ignore[attr-defined] + +class SaveSnapshotArgs: + """Arguments for `save-snapshot`.""" + def __init__( + self, + *, + backend_id: str, + location: SnapshotLocation, + snapshot_id: str, + is_latest: bool, + snapshot: Snapshot, + ) -> None: + setattr(self, 'backend-id', backend_id) + setattr(self, 'location', location) + setattr(self, 'snapshot-id', snapshot_id) + setattr(self, 'is-latest', is_latest) + setattr(self, 'snapshot', snapshot) + + @property + def backend_id(self) -> str: + return getattr(self, 'backend-id') + + @property + def snapshot_id(self) -> str: + return getattr(self, 'snapshot-id') + + @property + def is_latest(self) -> bool: + return getattr(self, 'is-latest') + + def __repr__(self) -> str: + return f'SaveSnapshotArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r}, snapshot_id={getattr(self, 'snapshot-id')!r}, is_latest={getattr(self, 'is-latest')!r}, snapshot={getattr(self, 'snapshot')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SaveSnapshotArgs): + return NotImplemented + return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'snapshot-id') == getattr(other, 'snapshot-id') and getattr(self, 'is-latest') == getattr(other, 'is-latest') and getattr(self, 'snapshot') == getattr(other, 'snapshot') + + def __hash__(self) -> int: + return id(self) + + +class LoadSnapshotArgs: + """Arguments for `load-snapshot`.""" + def __init__( + self, + *, + backend_id: str, + location: SnapshotLocation, + snapshot_id: Optional[str], + ) -> None: + setattr(self, 'backend-id', backend_id) + setattr(self, 'location', location) + setattr(self, 'snapshot-id', snapshot_id) + + @property + def backend_id(self) -> str: + return getattr(self, 'backend-id') + + @property + def snapshot_id(self) -> Optional[str]: + return getattr(self, 'snapshot-id') + + def __repr__(self) -> str: + return f'LoadSnapshotArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r}, snapshot_id={getattr(self, 'snapshot-id')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, LoadSnapshotArgs): + return NotImplemented + return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'snapshot-id') == getattr(other, 'snapshot-id') + + def __hash__(self) -> int: + return id(self) + + +class ListSnapshotIdsArgs: + """Arguments for `list-snapshot-ids`.""" + def __init__( + self, + *, + backend_id: str, + location: SnapshotLocation, + limit: Optional[int], + start_after: Optional[str], + ) -> None: + setattr(self, 'backend-id', backend_id) + setattr(self, 'location', location) + setattr(self, 'limit', limit) + setattr(self, 'start-after', start_after) + + @property + def backend_id(self) -> str: + return getattr(self, 'backend-id') + + @property + def start_after(self) -> Optional[str]: + return getattr(self, 'start-after') + + def __repr__(self) -> str: + return f'ListSnapshotIdsArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r}, limit={getattr(self, 'limit')!r}, start_after={getattr(self, 'start-after')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ListSnapshotIdsArgs): + return NotImplemented + return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'limit') == getattr(other, 'limit') and getattr(self, 'start-after') == getattr(other, 'start-after') + + def __hash__(self) -> int: + return id(self) + + +class DeleteSessionArgs: + """Arguments for `delete-session`.""" + def __init__( + self, + *, + backend_id: str, + session_id: str, + ) -> None: + setattr(self, 'backend-id', backend_id) + setattr(self, 'session-id', session_id) + + @property + def backend_id(self) -> str: + return getattr(self, 'backend-id') + + @property + def session_id(self) -> str: + return getattr(self, 'session-id') + + def __repr__(self) -> str: + return f'DeleteSessionArgs(backend_id={getattr(self, 'backend-id')!r}, session_id={getattr(self, 'session-id')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DeleteSessionArgs): + return NotImplemented + return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'session-id') == getattr(other, 'session-id') + + def __hash__(self) -> int: + return id(self) + + +class ManifestArgs: + """Arguments for `load-manifest` / `save-manifest`.""" + def __init__( + self, + *, + backend_id: str, + location: SnapshotLocation, + ) -> None: + setattr(self, 'backend-id', backend_id) + setattr(self, 'location', location) + + @property + def backend_id(self) -> str: + return getattr(self, 'backend-id') + + def __repr__(self) -> str: + return f'ManifestArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ManifestArgs): + return NotImplemented + return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') + + def __hash__(self) -> int: + return id(self) + + +class SaveManifestArgs: + """Arguments for `save-manifest`.""" + def __init__( + self, + *, + backend_id: str, + location: SnapshotLocation, + manifest: SnapshotManifest, + ) -> None: + setattr(self, 'backend-id', backend_id) + setattr(self, 'location', location) + setattr(self, 'manifest', manifest) + + @property + def backend_id(self) -> str: + return getattr(self, 'backend-id') + + def __repr__(self) -> str: + return f'SaveManifestArgs(backend_id={getattr(self, 'backend-id')!r}, location={getattr(self, 'location')!r}, manifest={getattr(self, 'manifest')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SaveManifestArgs): + return NotImplemented + return getattr(self, 'backend-id') == getattr(other, 'backend-id') and getattr(self, 'location') == getattr(other, 'location') and getattr(self, 'manifest') == getattr(other, 'manifest') + + def __hash__(self) -> int: + return id(self) + + +class TriggerParams: + """Context passed to the trigger on each call.""" + def __init__( + self, + *, + trigger_id: str, + message_count: int, + last_message: Optional[Message], + ) -> None: + setattr(self, 'trigger-id', trigger_id) + setattr(self, 'message-count', message_count) + setattr(self, 'last-message', last_message) + + @property + def trigger_id(self) -> str: + return getattr(self, 'trigger-id') + + @property + def message_count(self) -> int: + return getattr(self, 'message-count') + + @property + def last_message(self) -> Optional[Message]: + return getattr(self, 'last-message') + + def __repr__(self) -> str: + return f'TriggerParams(trigger_id={getattr(self, 'trigger-id')!r}, message_count={getattr(self, 'message-count')!r}, last_message={getattr(self, 'last-message')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TriggerParams): + return NotImplemented + return getattr(self, 'trigger-id') == getattr(other, 'trigger-id') and getattr(self, 'message-count') == getattr(other, 'message-count') and getattr(self, 'last-message') == getattr(other, 'last-message') + + def __hash__(self) -> int: + return id(self) + + +class TriggerError: + """Why a trigger evaluation failed.""" + pass + +class TriggerError_Unknown(TriggerError, _WitVariantCase): + """No trigger registered for the given id.""" + tag = 'unknown' + +class TriggerError_Failed(TriggerError, _WitVariantCase): + """Trigger raised an exception.""" + tag = 'failed' + +_TriggerError_CASES: dict[str, type] = { + 'unknown': TriggerError_Unknown, + 'failed': TriggerError_Failed, +} + +def _TriggerError_lift(raw: _WitVariant) -> TriggerError: + cls = _TriggerError_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown TriggerError arm: {raw.tag!r}') + return cls(raw.payload) +TriggerError.lift = staticmethod(_TriggerError_lift) # type: ignore[attr-defined] + +class ToolSpec: + """Declaration of a tool the model can call.""" + def __init__( + self, + *, + name: str, + description: str, + input_schema: str, + ) -> None: + setattr(self, 'name', name) + setattr(self, 'description', description) + setattr(self, 'input-schema', input_schema) + + @property + def input_schema(self) -> str: + return getattr(self, 'input-schema') + + def __repr__(self) -> str: + return f'ToolSpec(name={getattr(self, 'name')!r}, description={getattr(self, 'description')!r}, input_schema={getattr(self, 'input-schema')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ToolSpec): + return NotImplemented + return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'description') == getattr(other, 'description') and getattr(self, 'input-schema') == getattr(other, 'input-schema') + + def __hash__(self) -> int: + return id(self) + + +class AgentAsToolConfig: + """Wrap a configured agent as a tool callable by the parent agent. The +child agent is instantiated at registration time.""" + def __init__( + self, + *, + name: Optional[str], + description: Optional[str], + preserve_context: bool, + agent_config: str, + ) -> None: + setattr(self, 'name', name) + setattr(self, 'description', description) + setattr(self, 'preserve-context', preserve_context) + setattr(self, 'agent-config', agent_config) + + @property + def preserve_context(self) -> bool: + return getattr(self, 'preserve-context') + + @property + def agent_config(self) -> str: + return getattr(self, 'agent-config') + + def __repr__(self) -> str: + return f'AgentAsToolConfig(name={getattr(self, 'name')!r}, description={getattr(self, 'description')!r}, preserve_context={getattr(self, 'preserve-context')!r}, agent_config={getattr(self, 'agent-config')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AgentAsToolConfig): + return NotImplemented + return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'description') == getattr(other, 'description') and getattr(self, 'preserve-context') == getattr(other, 'preserve-context') and getattr(self, 'agent-config') == getattr(other, 'agent-config') + + def __hash__(self) -> int: + return id(self) + + +class CallToolArgs: + """Arguments for a single tool call.""" + def __init__( + self, + *, + name: str, + input: str, + tool_use_id: str, + ) -> None: + setattr(self, 'name', name) + setattr(self, 'input', input) + setattr(self, 'tool-use-id', tool_use_id) + + @property + def tool_use_id(self) -> str: + return getattr(self, 'tool-use-id') + + def __repr__(self) -> str: + return f'CallToolArgs(name={getattr(self, 'name')!r}, input={getattr(self, 'input')!r}, tool_use_id={getattr(self, 'tool-use-id')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CallToolArgs): + return NotImplemented + return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'input') == getattr(other, 'input') and getattr(self, 'tool-use-id') == getattr(other, 'tool-use-id') + + def __hash__(self) -> int: + return id(self) + + +class ToolChoice: + """Policy controlling whether and how the model calls tools on the next +generation step.""" + pass + +class ToolChoice_Auto(ToolChoice, _WitVariantCase): + """Model decides whether to call a tool.""" + tag = 'auto' + +class ToolChoice_Any(ToolChoice, _WitVariantCase): + """Model must call at least one tool.""" + tag = 'any' + +class ToolChoice_Named(ToolChoice, _WitVariantCase): + """Model must call the tool with this name.""" + tag = 'named' + +_ToolChoice_CASES: dict[str, type] = { + 'auto': ToolChoice_Auto, + 'any': ToolChoice_Any, + 'named': ToolChoice_Named, +} + +def _ToolChoice_lift(raw: _WitVariant) -> ToolChoice: + cls = _ToolChoice_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ToolChoice arm: {raw.tag!r}') + return cls(raw.payload) +ToolChoice.lift = staticmethod(_ToolChoice_lift) # type: ignore[attr-defined] + +class ToolEventStream: + """Pull-based stream of tool events. Sync-WIT placeholder for +`stream`.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def read(self) -> Optional[ToolStreamEvent]: + return self._invoke('[method]tool-event-stream.read', (self._handle,)) + + +class ToolError: + """Why a tool call failed.""" + pass + +class ToolError_Unknown(ToolError, _WitVariantCase): + """No tool registered under the given name.""" + tag = 'unknown' + +class ToolError_InvalidInput(ToolError, _WitVariantCase): + """Tool input didn't match the declared input schema.""" + tag = 'invalid-input' + +class ToolError_ExecutionFailed(ToolError, _WitVariantCase): + """Tool ran but returned an error result.""" + tag = 'execution-failed' + +class ToolError_TimedOut(ToolError, _WitVariantCase): + """Tool exceeded its time budget.""" + tag = 'timed-out' + +class ToolError_Cancelled(ToolError, _WitVariantCase): + """Tool was cancelled before completion.""" + tag = 'cancelled' + +class ToolError_Internal(ToolError, _WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + +_ToolError_CASES: dict[str, type] = { + 'unknown': ToolError_Unknown, + 'invalid-input': ToolError_InvalidInput, + 'execution-failed': ToolError_ExecutionFailed, + 'timed-out': ToolError_TimedOut, + 'cancelled': ToolError_Cancelled, + 'internal': ToolError_Internal, +} + +def _ToolError_lift(raw: _WitVariant) -> ToolError: + cls = _ToolError_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown ToolError arm: {raw.tag!r}') + return cls(raw.payload) +ToolError.lift = staticmethod(_ToolError_lift) # type: ignore[attr-defined] + +ToolStreamEvent = str | list[ToolResultContent] | ToolError +"""Incremental event emitted by a streaming tool while running.""" + +class McpConnectionState(str): + """Connection state of an MCP client.""" + __slots__ = () + + DISCONNECTED: 'McpConnectionState' + CONNECTED: 'McpConnectionState' + FAILED: 'McpConnectionState' + +McpConnectionState.DISCONNECTED = McpConnectionState('disconnected') # type: ignore[attr-defined] +McpConnectionState.CONNECTED = McpConnectionState('connected') # type: ignore[attr-defined] +McpConnectionState.FAILED = McpConnectionState('failed') # type: ignore[attr-defined] + + +class EnvVar: + """Single environment variable entry.""" + def __init__( + self, + *, + key: str, + value: str, + ) -> None: + setattr(self, 'key', key) + setattr(self, 'value', value) + + def __repr__(self) -> str: + return f'EnvVar(key={getattr(self, 'key')!r}, value={getattr(self, 'value')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EnvVar): + return NotImplemented + return getattr(self, 'key') == getattr(other, 'key') and getattr(self, 'value') == getattr(other, 'value') + + def __hash__(self) -> int: + return id(self) + + +class StdioTransportConfig: + """STDIO transport configuration.""" + def __init__( + self, + *, + command: str, + args: list[str], + env: list[EnvVar], + cwd: Optional[str], + ) -> None: + setattr(self, 'command', command) + setattr(self, 'args', args) + setattr(self, 'env', env) + setattr(self, 'cwd', cwd) + + def __repr__(self) -> str: + return f'StdioTransportConfig(command={getattr(self, 'command')!r}, args={getattr(self, 'args')!r}, env={getattr(self, 'env')!r}, cwd={getattr(self, 'cwd')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, StdioTransportConfig): + return NotImplemented + return getattr(self, 'command') == getattr(other, 'command') and getattr(self, 'args') == getattr(other, 'args') and getattr(self, 'env') == getattr(other, 'env') and getattr(self, 'cwd') == getattr(other, 'cwd') + + def __hash__(self) -> int: + return id(self) + + +class HttpHeader: + """Single HTTP header entry.""" + def __init__( + self, + *, + name: str, + value: str, + ) -> None: + setattr(self, 'name', name) + setattr(self, 'value', value) + + def __repr__(self) -> str: + return f'HttpHeader(name={getattr(self, 'name')!r}, value={getattr(self, 'value')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, HttpHeader): + return NotImplemented + return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'value') == getattr(other, 'value') + + def __hash__(self) -> int: + return id(self) + + +class HttpTransportConfig: + """HTTP transport configuration.""" + def __init__( + self, + *, + url: str, + headers: list[HttpHeader], + ) -> None: + setattr(self, 'url', url) + setattr(self, 'headers', headers) + + def __repr__(self) -> str: + return f'HttpTransportConfig(url={getattr(self, 'url')!r}, headers={getattr(self, 'headers')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, HttpTransportConfig): + return NotImplemented + return getattr(self, 'url') == getattr(other, 'url') and getattr(self, 'headers') == getattr(other, 'headers') + + def __hash__(self) -> int: + return id(self) + + +class SseTransportConfig: + """SSE transport configuration.""" + def __init__( + self, + *, + url: str, + headers: list[HttpHeader], + ) -> None: + setattr(self, 'url', url) + setattr(self, 'headers', headers) + + def __repr__(self) -> str: + return f'SseTransportConfig(url={getattr(self, 'url')!r}, headers={getattr(self, 'headers')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SseTransportConfig): + return NotImplemented + return getattr(self, 'url') == getattr(other, 'url') and getattr(self, 'headers') == getattr(other, 'headers') + + def __hash__(self) -> int: + return id(self) + + +class McpTransport: + """How the client talks to the MCP server.""" + pass + +class McpTransport_Stdio(McpTransport, _WitVariantCase): + """STDIO transport. Spawn a local process and talk via pipes.""" + tag = 'stdio' + +class McpTransport_StreamableHttp(McpTransport, _WitVariantCase): + """Streamable HTTP transport, per the current MCP specification.""" + tag = 'streamable-http' + +class McpTransport_Sse(McpTransport, _WitVariantCase): + """Legacy Server-Sent Events transport. Retained for older servers.""" + tag = 'sse' + +_McpTransport_CASES: dict[str, type] = { + 'stdio': McpTransport_Stdio, + 'streamable-http': McpTransport_StreamableHttp, + 'sse': McpTransport_Sse, +} + +def _McpTransport_lift(raw: _WitVariant) -> McpTransport: + cls = _McpTransport_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown McpTransport arm: {raw.tag!r}') + return cls(raw.payload) +McpTransport.lift = staticmethod(_McpTransport_lift) # type: ignore[attr-defined] + +class TasksConfig: + """Task-augmented tool execution. Enables long-running tools with +progress tracking. Experimental in the MCP specification.""" + def __init__( + self, + *, + ttl: int, + poll_timeout: int, + ) -> None: + setattr(self, 'ttl', ttl) + setattr(self, 'poll-timeout', poll_timeout) + + @property + def poll_timeout(self) -> int: + return getattr(self, 'poll-timeout') + + def __repr__(self) -> str: + return f'TasksConfig(ttl={getattr(self, 'ttl')!r}, poll_timeout={getattr(self, 'poll-timeout')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TasksConfig): + return NotImplemented + return getattr(self, 'ttl') == getattr(other, 'ttl') and getattr(self, 'poll-timeout') == getattr(other, 'poll-timeout') + + def __hash__(self) -> int: + return id(self) + + +class McpClientConfig: + """MCP client configuration.""" + def __init__( + self, + *, + client_id: str, + application_name: Optional[str], + application_version: Optional[str], + transport: McpTransport, + tasks_config: Optional[TasksConfig], + elicitation_enabled: bool, + fail_open: bool, + disable_instrumentation: bool, + ) -> None: + setattr(self, 'client-id', client_id) + setattr(self, 'application-name', application_name) + setattr(self, 'application-version', application_version) + setattr(self, 'transport', transport) + setattr(self, 'tasks-config', tasks_config) + setattr(self, 'elicitation-enabled', elicitation_enabled) + setattr(self, 'fail-open', fail_open) + setattr(self, 'disable-instrumentation', disable_instrumentation) + + @property + def client_id(self) -> str: + return getattr(self, 'client-id') + + @property + def application_name(self) -> Optional[str]: + return getattr(self, 'application-name') + + @property + def application_version(self) -> Optional[str]: + return getattr(self, 'application-version') + + @property + def tasks_config(self) -> Optional[TasksConfig]: + return getattr(self, 'tasks-config') + + @property + def elicitation_enabled(self) -> bool: + return getattr(self, 'elicitation-enabled') + + @property + def fail_open(self) -> bool: + return getattr(self, 'fail-open') + + @property + def disable_instrumentation(self) -> bool: + return getattr(self, 'disable-instrumentation') + + def __repr__(self) -> str: + return f'McpClientConfig(client_id={getattr(self, 'client-id')!r}, application_name={getattr(self, 'application-name')!r}, application_version={getattr(self, 'application-version')!r}, transport={getattr(self, 'transport')!r}, tasks_config={getattr(self, 'tasks-config')!r}, elicitation_enabled={getattr(self, 'elicitation-enabled')!r}, fail_open={getattr(self, 'fail-open')!r}, disable_instrumentation={getattr(self, 'disable-instrumentation')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, McpClientConfig): + return NotImplemented + return getattr(self, 'client-id') == getattr(other, 'client-id') and getattr(self, 'application-name') == getattr(other, 'application-name') and getattr(self, 'application-version') == getattr(other, 'application-version') and getattr(self, 'transport') == getattr(other, 'transport') and getattr(self, 'tasks-config') == getattr(other, 'tasks-config') and getattr(self, 'elicitation-enabled') == getattr(other, 'elicitation-enabled') and getattr(self, 'fail-open') == getattr(other, 'fail-open') and getattr(self, 'disable-instrumentation') == getattr(other, 'disable-instrumentation') + + def __hash__(self) -> int: + return id(self) + + +class Interrupt: + """Human-in-the-loop interrupt raised by a tool or hook.""" + def __init__( + self, + *, + id: str, + name: str, + reason: Optional[str], + ) -> None: + setattr(self, 'id', id) + setattr(self, 'name', name) + setattr(self, 'reason', reason) + + def __repr__(self) -> str: + return f'Interrupt(id={getattr(self, 'id')!r}, name={getattr(self, 'name')!r}, reason={getattr(self, 'reason')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Interrupt): + return NotImplemented + return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'reason') == getattr(other, 'reason') + + def __hash__(self) -> int: + return id(self) + + +class StopReason(str): + """Why the model stopped generating.""" + __slots__ = () + + END_TURN: 'StopReason' + TOOL_USE: 'StopReason' + MAX_TOKENS: 'StopReason' + ERROR: 'StopReason' + CONTENT_FILTERED: 'StopReason' + GUARDRAIL_INTERVENED: 'StopReason' + STOP_SEQUENCE: 'StopReason' + MODEL_CONTEXT_WINDOW_EXCEEDED: 'StopReason' + CANCELLED: 'StopReason' + +StopReason.END_TURN = StopReason('end-turn') # type: ignore[attr-defined] +StopReason.TOOL_USE = StopReason('tool-use') # type: ignore[attr-defined] +StopReason.MAX_TOKENS = StopReason('max-tokens') # type: ignore[attr-defined] +StopReason.ERROR = StopReason('error') # type: ignore[attr-defined] +StopReason.CONTENT_FILTERED = StopReason('content-filtered') # type: ignore[attr-defined] +StopReason.GUARDRAIL_INTERVENED = StopReason('guardrail-intervened') # type: ignore[attr-defined] +StopReason.STOP_SEQUENCE = StopReason('stop-sequence') # type: ignore[attr-defined] +StopReason.MODEL_CONTEXT_WINDOW_EXCEEDED = StopReason('model-context-window-exceeded') # type: ignore[attr-defined] +StopReason.CANCELLED = StopReason('cancelled') # type: ignore[attr-defined] + + +class MetadataEvent: + """Usage and metrics accumulated so far.""" + def __init__( + self, + *, + usage: Optional[Usage], + metrics: Optional[Metrics], + ) -> None: + setattr(self, 'usage', usage) + setattr(self, 'metrics', metrics) + + def __repr__(self) -> str: + return f'MetadataEvent(usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MetadataEvent): + return NotImplemented + return getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') + + def __hash__(self) -> int: + return id(self) + + +class TraceMetadataEntry: + """Single key-value pair attached to a trace. Values are string-typed +to keep traces compact; structured payloads belong on `message`.""" + def __init__( + self, + *, + key: str, + value: str, + ) -> None: + setattr(self, 'key', key) + setattr(self, 'value', value) + + def __repr__(self) -> str: + return f'TraceMetadataEntry(key={getattr(self, 'key')!r}, value={getattr(self, 'value')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TraceMetadataEntry): + return NotImplemented + return getattr(self, 'key') == getattr(other, 'key') and getattr(self, 'value') == getattr(other, 'value') + + def __hash__(self) -> int: + return id(self) + + +class AgentTrace: + """In-memory trace node. Returned flat; reconstruct the tree via `parent-id`.""" + def __init__( + self, + *, + id: str, + name: str, + parent_id: Optional[str], + start_time_ms: int, + end_time_ms: Optional[int], + duration_ms: int, + metadata: list[TraceMetadataEntry], + message: Optional[Message], + ) -> None: + setattr(self, 'id', id) + setattr(self, 'name', name) + setattr(self, 'parent-id', parent_id) + setattr(self, 'start-time-ms', start_time_ms) + setattr(self, 'end-time-ms', end_time_ms) + setattr(self, 'duration-ms', duration_ms) + setattr(self, 'metadata', metadata) + setattr(self, 'message', message) + + @property + def parent_id(self) -> Optional[str]: + return getattr(self, 'parent-id') + + @property + def start_time_ms(self) -> int: + return getattr(self, 'start-time-ms') + + @property + def end_time_ms(self) -> Optional[int]: + return getattr(self, 'end-time-ms') + + @property + def duration_ms(self) -> int: + return getattr(self, 'duration-ms') + + def __repr__(self) -> str: + return f'AgentTrace(id={getattr(self, 'id')!r}, name={getattr(self, 'name')!r}, parent_id={getattr(self, 'parent-id')!r}, start_time_ms={getattr(self, 'start-time-ms')!r}, end_time_ms={getattr(self, 'end-time-ms')!r}, duration_ms={getattr(self, 'duration-ms')!r}, metadata={getattr(self, 'metadata')!r}, message={getattr(self, 'message')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AgentTrace): + return NotImplemented + return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'parent-id') == getattr(other, 'parent-id') and getattr(self, 'start-time-ms') == getattr(other, 'start-time-ms') and getattr(self, 'end-time-ms') == getattr(other, 'end-time-ms') and getattr(self, 'duration-ms') == getattr(other, 'duration-ms') and getattr(self, 'metadata') == getattr(other, 'metadata') and getattr(self, 'message') == getattr(other, 'message') + + def __hash__(self) -> int: + return id(self) + + +class ToolMetrics: + """Per-tool execution metrics keyed by tool name in `agent-metrics`.""" + def __init__( + self, + *, + tool_name: str, + call_count: int, + success_count: int, + error_count: int, + total_time_ms: int, + ) -> None: + setattr(self, 'tool-name', tool_name) + setattr(self, 'call-count', call_count) + setattr(self, 'success-count', success_count) + setattr(self, 'error-count', error_count) + setattr(self, 'total-time-ms', total_time_ms) + + @property + def tool_name(self) -> str: + return getattr(self, 'tool-name') + + @property + def call_count(self) -> int: + return getattr(self, 'call-count') + + @property + def success_count(self) -> int: + return getattr(self, 'success-count') + + @property + def error_count(self) -> int: + return getattr(self, 'error-count') + + @property + def total_time_ms(self) -> int: + return getattr(self, 'total-time-ms') + + def __repr__(self) -> str: + return f'ToolMetrics(tool_name={getattr(self, 'tool-name')!r}, call_count={getattr(self, 'call-count')!r}, success_count={getattr(self, 'success-count')!r}, error_count={getattr(self, 'error-count')!r}, total_time_ms={getattr(self, 'total-time-ms')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ToolMetrics): + return NotImplemented + return getattr(self, 'tool-name') == getattr(other, 'tool-name') and getattr(self, 'call-count') == getattr(other, 'call-count') and getattr(self, 'success-count') == getattr(other, 'success-count') and getattr(self, 'error-count') == getattr(other, 'error-count') and getattr(self, 'total-time-ms') == getattr(other, 'total-time-ms') + + def __hash__(self) -> int: + return id(self) + + +class InvocationMetrics: + """Per-invocation metrics. Cycles are flattened into `agent-metrics.cycles` +and linked back via `invocation-id`.""" + def __init__( + self, + *, + invocation_id: str, + usage: Usage, + ) -> None: + setattr(self, 'invocation-id', invocation_id) + setattr(self, 'usage', usage) + + @property + def invocation_id(self) -> str: + return getattr(self, 'invocation-id') + + def __repr__(self) -> str: + return f'InvocationMetrics(invocation_id={getattr(self, 'invocation-id')!r}, usage={getattr(self, 'usage')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, InvocationMetrics): + return NotImplemented + return getattr(self, 'invocation-id') == getattr(other, 'invocation-id') and getattr(self, 'usage') == getattr(other, 'usage') + + def __hash__(self) -> int: + return id(self) + + +class AgentLoopMetrics: + """Per-cycle usage tracking.""" + def __init__( + self, + *, + cycle_id: str, + invocation_id: str, + duration_ms: int, + usage: Usage, + ) -> None: + setattr(self, 'cycle-id', cycle_id) + setattr(self, 'invocation-id', invocation_id) + setattr(self, 'duration-ms', duration_ms) + setattr(self, 'usage', usage) + + @property + def cycle_id(self) -> str: + return getattr(self, 'cycle-id') + + @property + def invocation_id(self) -> str: + return getattr(self, 'invocation-id') + + @property + def duration_ms(self) -> int: + return getattr(self, 'duration-ms') + + def __repr__(self) -> str: + return f'AgentLoopMetrics(cycle_id={getattr(self, 'cycle-id')!r}, invocation_id={getattr(self, 'invocation-id')!r}, duration_ms={getattr(self, 'duration-ms')!r}, usage={getattr(self, 'usage')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AgentLoopMetrics): + return NotImplemented + return getattr(self, 'cycle-id') == getattr(other, 'cycle-id') and getattr(self, 'invocation-id') == getattr(other, 'invocation-id') and getattr(self, 'duration-ms') == getattr(other, 'duration-ms') and getattr(self, 'usage') == getattr(other, 'usage') + + def __hash__(self) -> int: + return id(self) + + +class AgentMetrics: + """Snapshot of agent metrics. Returned by `agent.get-metrics`.""" + def __init__( + self, + *, + cycle_count: int, + accumulated_usage: Usage, + accumulated_metrics: Metrics, + invocations: list[InvocationMetrics], + cycles: list[AgentLoopMetrics], + tool_metrics: list[ToolMetrics], + latest_context_size: Optional[int], + projected_context_size: Optional[int], + ) -> None: + setattr(self, 'cycle-count', cycle_count) + setattr(self, 'accumulated-usage', accumulated_usage) + setattr(self, 'accumulated-metrics', accumulated_metrics) + setattr(self, 'invocations', invocations) + setattr(self, 'cycles', cycles) + setattr(self, 'tool-metrics', tool_metrics) + setattr(self, 'latest-context-size', latest_context_size) + setattr(self, 'projected-context-size', projected_context_size) + + @property + def cycle_count(self) -> int: + return getattr(self, 'cycle-count') + + @property + def accumulated_usage(self) -> Usage: + return getattr(self, 'accumulated-usage') + + @property + def accumulated_metrics(self) -> Metrics: + return getattr(self, 'accumulated-metrics') + + @property + def tool_metrics(self) -> list[ToolMetrics]: + return getattr(self, 'tool-metrics') + + @property + def latest_context_size(self) -> Optional[int]: + return getattr(self, 'latest-context-size') + + @property + def projected_context_size(self) -> Optional[int]: + return getattr(self, 'projected-context-size') + + def __repr__(self) -> str: + return f'AgentMetrics(cycle_count={getattr(self, 'cycle-count')!r}, accumulated_usage={getattr(self, 'accumulated-usage')!r}, accumulated_metrics={getattr(self, 'accumulated-metrics')!r}, invocations={getattr(self, 'invocations')!r}, cycles={getattr(self, 'cycles')!r}, tool_metrics={getattr(self, 'tool-metrics')!r}, latest_context_size={getattr(self, 'latest-context-size')!r}, projected_context_size={getattr(self, 'projected-context-size')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AgentMetrics): + return NotImplemented + return getattr(self, 'cycle-count') == getattr(other, 'cycle-count') and getattr(self, 'accumulated-usage') == getattr(other, 'accumulated-usage') and getattr(self, 'accumulated-metrics') == getattr(other, 'accumulated-metrics') and getattr(self, 'invocations') == getattr(other, 'invocations') and getattr(self, 'cycles') == getattr(other, 'cycles') and getattr(self, 'tool-metrics') == getattr(other, 'tool-metrics') and getattr(self, 'latest-context-size') == getattr(other, 'latest-context-size') and getattr(self, 'projected-context-size') == getattr(other, 'projected-context-size') + + def __hash__(self) -> int: + return id(self) + + +class ToolUseData: + """Mutable tool-use descriptor. `before-tool-call` hooks may rewrite fields.""" + def __init__( + self, + *, + name: str, + tool_use_id: str, + input: str, + ) -> None: + setattr(self, 'name', name) + setattr(self, 'tool-use-id', tool_use_id) + setattr(self, 'input', input) + + @property + def tool_use_id(self) -> str: + return getattr(self, 'tool-use-id') + + def __repr__(self) -> str: + return f'ToolUseData(name={getattr(self, 'name')!r}, tool_use_id={getattr(self, 'tool-use-id')!r}, input={getattr(self, 'input')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ToolUseData): + return NotImplemented + return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'tool-use-id') == getattr(other, 'tool-use-id') and getattr(self, 'input') == getattr(other, 'input') + + def __hash__(self) -> int: + return id(self) + + +class HookRedaction: + """Redaction information when guardrails block content.""" + def __init__( + self, + *, + user_message: str, + ) -> None: + setattr(self, 'user-message', user_message) + + @property + def user_message(self) -> str: + return getattr(self, 'user-message') + + def __repr__(self) -> str: + return f'HookRedaction(user_message={getattr(self, 'user-message')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, HookRedaction): + return NotImplemented + return getattr(self, 'user-message') == getattr(other, 'user-message') + + def __hash__(self) -> int: + return id(self) + + +class ModelStopData: + """Model response surfaced on `after-model-call`.""" + def __init__( + self, + *, + message: Message, + stop_reason: StopReason, + redaction: Optional[HookRedaction], + ) -> None: + setattr(self, 'message', message) + setattr(self, 'stop-reason', stop_reason) + setattr(self, 'redaction', redaction) + + @property + def stop_reason(self) -> StopReason: + return getattr(self, 'stop-reason') + + def __repr__(self) -> str: + return f'ModelStopData(message={getattr(self, 'message')!r}, stop_reason={getattr(self, 'stop-reason')!r}, redaction={getattr(self, 'redaction')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ModelStopData): + return NotImplemented + return getattr(self, 'message') == getattr(other, 'message') and getattr(self, 'stop-reason') == getattr(other, 'stop-reason') and getattr(self, 'redaction') == getattr(other, 'redaction') + + def __hash__(self) -> int: + return id(self) + + +class BeforeInvocationData: + """Payload for `before-invocation`.""" + def __init__( + self, + *, + invocation_state: str, + ) -> None: + setattr(self, 'invocation-state', invocation_state) + + @property + def invocation_state(self) -> str: + return getattr(self, 'invocation-state') + + def __repr__(self) -> str: + return f'BeforeInvocationData(invocation_state={getattr(self, 'invocation-state')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, BeforeInvocationData): + return NotImplemented + return getattr(self, 'invocation-state') == getattr(other, 'invocation-state') + + def __hash__(self) -> int: + return id(self) + + +class AfterInvocationData: + """Payload for `after-invocation`.""" + def __init__( + self, + *, + invocation_state: str, + ) -> None: + setattr(self, 'invocation-state', invocation_state) + + @property + def invocation_state(self) -> str: + return getattr(self, 'invocation-state') + + def __repr__(self) -> str: + return f'AfterInvocationData(invocation_state={getattr(self, 'invocation-state')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AfterInvocationData): + return NotImplemented + return getattr(self, 'invocation-state') == getattr(other, 'invocation-state') + + def __hash__(self) -> int: + return id(self) + + +class MessageAddedData: + """Payload for `message-added`.""" + def __init__( + self, + *, + message: Message, + ) -> None: + setattr(self, 'message', message) + + def __repr__(self) -> str: + return f'MessageAddedData(message={getattr(self, 'message')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MessageAddedData): + return NotImplemented + return getattr(self, 'message') == getattr(other, 'message') + + def __hash__(self) -> int: + return id(self) + + +class BeforeModelCallData: + """Payload for `before-model-call`.""" + def __init__( + self, + *, + projected_input_tokens: Optional[int], + ) -> None: + setattr(self, 'projected-input-tokens', projected_input_tokens) + + @property + def projected_input_tokens(self) -> Optional[int]: + return getattr(self, 'projected-input-tokens') + + def __repr__(self) -> str: + return f'BeforeModelCallData(projected_input_tokens={getattr(self, 'projected-input-tokens')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, BeforeModelCallData): + return NotImplemented + return getattr(self, 'projected-input-tokens') == getattr(other, 'projected-input-tokens') + + def __hash__(self) -> int: + return id(self) + + +class AfterModelCallData: + """Payload for `after-model-call`.""" + def __init__( + self, + *, + attempt_count: int, + stop_data: Optional[ModelStopData], + error: Optional[ModelError], + ) -> None: + setattr(self, 'attempt-count', attempt_count) + setattr(self, 'stop-data', stop_data) + setattr(self, 'error', error) + + @property + def attempt_count(self) -> int: + return getattr(self, 'attempt-count') + + @property + def stop_data(self) -> Optional[ModelStopData]: + return getattr(self, 'stop-data') + + def __repr__(self) -> str: + return f'AfterModelCallData(attempt_count={getattr(self, 'attempt-count')!r}, stop_data={getattr(self, 'stop-data')!r}, error={getattr(self, 'error')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AfterModelCallData): + return NotImplemented + return getattr(self, 'attempt-count') == getattr(other, 'attempt-count') and getattr(self, 'stop-data') == getattr(other, 'stop-data') and getattr(self, 'error') == getattr(other, 'error') + + def __hash__(self) -> int: + return id(self) + + +class BeforeToolCallData: + """Payload for `before-tool-call`.""" + def __init__( + self, + *, + tool_use: ToolUseData, + ) -> None: + setattr(self, 'tool-use', tool_use) + + @property + def tool_use(self) -> ToolUseData: + return getattr(self, 'tool-use') + + def __repr__(self) -> str: + return f'BeforeToolCallData(tool_use={getattr(self, 'tool-use')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, BeforeToolCallData): + return NotImplemented + return getattr(self, 'tool-use') == getattr(other, 'tool-use') + + def __hash__(self) -> int: + return id(self) + + +class AfterToolCallData: + """Payload for `after-tool-call`.""" + def __init__( + self, + *, + tool_use: ToolUseData, + tool_result: ToolResultBlock, + error: Optional[ToolError], + ) -> None: + setattr(self, 'tool-use', tool_use) + setattr(self, 'tool-result', tool_result) + setattr(self, 'error', error) + + @property + def tool_use(self) -> ToolUseData: + return getattr(self, 'tool-use') + + @property + def tool_result(self) -> ToolResultBlock: + return getattr(self, 'tool-result') + + def __repr__(self) -> str: + return f'AfterToolCallData(tool_use={getattr(self, 'tool-use')!r}, tool_result={getattr(self, 'tool-result')!r}, error={getattr(self, 'error')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AfterToolCallData): + return NotImplemented + return getattr(self, 'tool-use') == getattr(other, 'tool-use') and getattr(self, 'tool-result') == getattr(other, 'tool-result') and getattr(self, 'error') == getattr(other, 'error') + + def __hash__(self) -> int: + return id(self) + + +class ToolsBatchData: + """Payload for `before-tools` / `after-tools`.""" + def __init__( + self, + *, + message: Message, + ) -> None: + setattr(self, 'message', message) + + def __repr__(self) -> str: + return f'ToolsBatchData(message={getattr(self, 'message')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ToolsBatchData): + return NotImplemented + return getattr(self, 'message') == getattr(other, 'message') + + def __hash__(self) -> int: + return id(self) + + +class ContentBlockData: + """Payload for `content-block`.""" + def __init__( + self, + *, + content_block: ContentBlock, + ) -> None: + setattr(self, 'content-block', content_block) + + @property + def content_block(self) -> ContentBlock: + return getattr(self, 'content-block') + + def __repr__(self) -> str: + return f'ContentBlockData(content_block={getattr(self, 'content-block')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ContentBlockData): + return NotImplemented + return getattr(self, 'content-block') == getattr(other, 'content-block') + + def __hash__(self) -> int: + return id(self) + + +class ModelMessageData: + """Payload for `model-message`.""" + def __init__( + self, + *, + message: Message, + stop_reason: StopReason, + ) -> None: + setattr(self, 'message', message) + setattr(self, 'stop-reason', stop_reason) + + @property + def stop_reason(self) -> StopReason: + return getattr(self, 'stop-reason') + + def __repr__(self) -> str: + return f'ModelMessageData(message={getattr(self, 'message')!r}, stop_reason={getattr(self, 'stop-reason')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ModelMessageData): + return NotImplemented + return getattr(self, 'message') == getattr(other, 'message') and getattr(self, 'stop-reason') == getattr(other, 'stop-reason') + + def __hash__(self) -> int: + return id(self) + + +class ToolResultData: + """Payload for `tool-result-hook`.""" + def __init__( + self, + *, + tool_result: ToolResultBlock, + ) -> None: + setattr(self, 'tool-result', tool_result) + + @property + def tool_result(self) -> ToolResultBlock: + return getattr(self, 'tool-result') + + def __repr__(self) -> str: + return f'ToolResultData(tool_result={getattr(self, 'tool-result')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ToolResultData): + return NotImplemented + return getattr(self, 'tool-result') == getattr(other, 'tool-result') + + def __hash__(self) -> int: + return id(self) + + +class ToolStreamUpdateData: + """Payload for `tool-stream-update`.""" + def __init__( + self, + *, + data: str, + ) -> None: + setattr(self, 'data', data) + + def __repr__(self) -> str: + return f'ToolStreamUpdateData(data={getattr(self, 'data')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ToolStreamUpdateData): + return NotImplemented + return getattr(self, 'data') == getattr(other, 'data') + + def __hash__(self) -> int: + return id(self) + + +class ModelStreamUpdateData: + """Payload for `model-stream-update`.""" + def __init__( + self, + *, + event: str, + ) -> None: + setattr(self, 'event', event) + + def __repr__(self) -> str: + return f'ModelStreamUpdateData(event={getattr(self, 'event')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ModelStreamUpdateData): + return NotImplemented + return getattr(self, 'event') == getattr(other, 'event') + + def __hash__(self) -> int: + return id(self) + + +class InputRedaction: + """Input redaction emitted when a guardrail blocks input. Original is in history.""" + def __init__( + self, + *, + replace_content: str, + ) -> None: + setattr(self, 'replace-content', replace_content) + + @property + def replace_content(self) -> str: + return getattr(self, 'replace-content') + + def __repr__(self) -> str: + return f'InputRedaction(replace_content={getattr(self, 'replace-content')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, InputRedaction): + return NotImplemented + return getattr(self, 'replace-content') == getattr(other, 'replace-content') + + def __hash__(self) -> int: + return id(self) + + +class OutputRedaction: + """Output redaction emitted when a guardrail blocks output.""" + def __init__( + self, + *, + redacted_content: Optional[str], + replace_content: str, + ) -> None: + setattr(self, 'redacted-content', redacted_content) + setattr(self, 'replace-content', replace_content) + + @property + def redacted_content(self) -> Optional[str]: + return getattr(self, 'redacted-content') + + @property + def replace_content(self) -> str: + return getattr(self, 'replace-content') + + def __repr__(self) -> str: + return f'OutputRedaction(redacted_content={getattr(self, 'redacted-content')!r}, replace_content={getattr(self, 'replace-content')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, OutputRedaction): + return NotImplemented + return getattr(self, 'redacted-content') == getattr(other, 'redacted-content') and getattr(self, 'replace-content') == getattr(other, 'replace-content') + + def __hash__(self) -> int: + return id(self) + + +class RedactionEvent: + """Redaction event. Input and output fields are independent; at least one is set.""" + def __init__( + self, + *, + input_redaction: Optional[InputRedaction], + output_redaction: Optional[OutputRedaction], + ) -> None: + setattr(self, 'input-redaction', input_redaction) + setattr(self, 'output-redaction', output_redaction) + + @property + def input_redaction(self) -> Optional[InputRedaction]: + return getattr(self, 'input-redaction') + + @property + def output_redaction(self) -> Optional[OutputRedaction]: + return getattr(self, 'output-redaction') + + def __repr__(self) -> str: + return f'RedactionEvent(input_redaction={getattr(self, 'input-redaction')!r}, output_redaction={getattr(self, 'output-redaction')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RedactionEvent): + return NotImplemented + return getattr(self, 'input-redaction') == getattr(other, 'input-redaction') and getattr(self, 'output-redaction') == getattr(other, 'output-redaction') + + def __hash__(self) -> int: + return id(self) + + +class StopEvent: + """Terminal event for a stream.""" + def __init__( + self, + *, + reason: StopReason, + usage: Optional[Usage], + metrics: Optional[Metrics], + structured_output: Optional[str], + ) -> None: + setattr(self, 'reason', reason) + setattr(self, 'usage', usage) + setattr(self, 'metrics', metrics) + setattr(self, 'structured-output', structured_output) + + @property + def structured_output(self) -> Optional[str]: + return getattr(self, 'structured-output') + + def __repr__(self) -> str: + return f'StopEvent(reason={getattr(self, 'reason')!r}, usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r}, structured_output={getattr(self, 'structured-output')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, StopEvent): + return NotImplemented + return getattr(self, 'reason') == getattr(other, 'reason') and getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') and getattr(self, 'structured-output') == getattr(other, 'structured-output') + + def __hash__(self) -> int: + return id(self) + + +class AgentResultData: + """Payload for `agent-result`.""" + def __init__( + self, + *, + stop: StopEvent, + ) -> None: + setattr(self, 'stop', stop) + + def __repr__(self) -> str: + return f'AgentResultData(stop={getattr(self, 'stop')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AgentResultData): + return NotImplemented + return getattr(self, 'stop') == getattr(other, 'stop') + + def __hash__(self) -> int: + return id(self) + + +class StreamError: + """Why the agent loop surfaced an error mid-stream.""" + pass + +class StreamError_Model(StreamError, _WitVariantCase): + """A model call failed.""" + tag = 'model' + +class StreamError_Tool(StreamError, _WitVariantCase): + """A tool call failed.""" + tag = 'tool' + +class StreamError_ContextWindowExceeded(StreamError, _WitVariantCase): + """Input exceeded the model's context window and no conversation +manager could recover.""" + tag = 'context-window-exceeded' + +class StreamError_MaxTokensReached(StreamError, _WitVariantCase): + """Exceeded the model's max-tokens budget mid-response.""" + tag = 'max-tokens-reached' + +class StreamError_StructuredOutputUnavailable(StreamError, _WitVariantCase): + """Structured output was requested but the model never called the +tool, even after being forced.""" + tag = 'structured-output-unavailable' + +class StreamError_Internal(StreamError, _WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + +_StreamError_CASES: dict[str, type] = { + 'model': StreamError_Model, + 'tool': StreamError_Tool, + 'context-window-exceeded': StreamError_ContextWindowExceeded, + 'max-tokens-reached': StreamError_MaxTokensReached, + 'structured-output-unavailable': StreamError_StructuredOutputUnavailable, + 'internal': StreamError_Internal, +} + +def _StreamError_lift(raw: _WitVariant) -> StreamError: + cls = _StreamError_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown StreamError arm: {raw.tag!r}') + return cls(raw.payload) +StreamError.lift = staticmethod(_StreamError_lift) # type: ignore[attr-defined] + +class StreamEvent: + """Events yielded during agent streaming. +Hot-path arms: `text-delta`, `tool-use`, `tool-result`. Other content +blocks flow through `content`. Lifecycle arms (`before-invocation` +through `agent-result`) mirror a hook system and can be filtered by tag.""" + pass + +class StreamEvent_TextDelta(StreamEvent, _WitVariantCase): + """Incremental text from the model.""" + tag = 'text-delta' + +class StreamEvent_ToolUse(StreamEvent, _WitVariantCase): + """Model requested a tool call.""" + tag = 'tool-use' + +class StreamEvent_ToolResult(StreamEvent, _WitVariantCase): + """Tool call completed.""" + tag = 'tool-result' + +class StreamEvent_Content(StreamEvent, _WitVariantCase): + """Non-hot-path content block (image, reasoning, citations, etc).""" + tag = 'content' + +class StreamEvent_Metadata(StreamEvent, _WitVariantCase): + """Cumulative usage and metrics snapshot.""" + tag = 'metadata' + +class StreamEvent_Stop(StreamEvent, _WitVariantCase): + """Terminal event for the stream.""" + tag = 'stop' + +class StreamEvent_Redaction(StreamEvent, _WitVariantCase): + """Guardrail redaction fired.""" + tag = 'redaction' + +class StreamEvent_Error(StreamEvent, _WitVariantCase): + """Recoverable error surfaced mid-stream.""" + tag = 'error' + +class StreamEvent_Interrupt(StreamEvent, _WitVariantCase): + """Human-in-the-loop pause; resume via `response-stream.respond`.""" + tag = 'interrupt' + +class StreamEvent_Initialized(StreamEvent, _WitVariantCase): + """Agent finished construction.""" + tag = 'initialized' + +class StreamEvent_BeforeInvocation(StreamEvent, _WitVariantCase): + """About to process a user invocation.""" + tag = 'before-invocation' + +class StreamEvent_AfterInvocation(StreamEvent, _WitVariantCase): + """Finished processing a user invocation.""" + tag = 'after-invocation' + +class StreamEvent_MessageAdded(StreamEvent, _WitVariantCase): + """A message was appended to the conversation.""" + tag = 'message-added' + +class StreamEvent_BeforeModelCall(StreamEvent, _WitVariantCase): + """About to call the model.""" + tag = 'before-model-call' + +class StreamEvent_AfterModelCall(StreamEvent, _WitVariantCase): + """Model call returned.""" + tag = 'after-model-call' + +class StreamEvent_BeforeTools(StreamEvent, _WitVariantCase): + """About to run a batch of tool calls from one assistant turn.""" + tag = 'before-tools' + +class StreamEvent_AfterTools(StreamEvent, _WitVariantCase): + """Tool batch finished.""" + tag = 'after-tools' + +class StreamEvent_BeforeToolCall(StreamEvent, _WitVariantCase): + """About to call a single tool.""" + tag = 'before-tool-call' + +class StreamEvent_AfterToolCall(StreamEvent, _WitVariantCase): + """Tool call returned.""" + tag = 'after-tool-call' + +class StreamEvent_ContentBlock(StreamEvent, _WitVariantCase): + """A content block was assembled during streaming.""" + tag = 'content-block' + +class StreamEvent_ModelMessage(StreamEvent, _WitVariantCase): + """Model finished producing a full message.""" + tag = 'model-message' + +class StreamEvent_ToolResultHook(StreamEvent, _WitVariantCase): + """Tool finished execution (completion event, not streaming update).""" + tag = 'tool-result-hook' + +class StreamEvent_ToolUpdate(StreamEvent, _WitVariantCase): + """Streaming update from a tool.""" + tag = 'tool-update' + +class StreamEvent_ModelUpdate(StreamEvent, _WitVariantCase): + """Streaming update from the model.""" + tag = 'model-update' + +class StreamEvent_AgentResult(StreamEvent, _WitVariantCase): + """Final event for an invocation, carrying the terminal result.""" + tag = 'agent-result' + +_StreamEvent_CASES: dict[str, type] = { + 'text-delta': StreamEvent_TextDelta, + 'tool-use': StreamEvent_ToolUse, + 'tool-result': StreamEvent_ToolResult, + 'content': StreamEvent_Content, + 'metadata': StreamEvent_Metadata, + 'stop': StreamEvent_Stop, + 'redaction': StreamEvent_Redaction, + 'error': StreamEvent_Error, + 'interrupt': StreamEvent_Interrupt, + 'initialized': StreamEvent_Initialized, + 'before-invocation': StreamEvent_BeforeInvocation, + 'after-invocation': StreamEvent_AfterInvocation, + 'message-added': StreamEvent_MessageAdded, + 'before-model-call': StreamEvent_BeforeModelCall, + 'after-model-call': StreamEvent_AfterModelCall, + 'before-tools': StreamEvent_BeforeTools, + 'after-tools': StreamEvent_AfterTools, + 'before-tool-call': StreamEvent_BeforeToolCall, + 'after-tool-call': StreamEvent_AfterToolCall, + 'content-block': StreamEvent_ContentBlock, + 'model-message': StreamEvent_ModelMessage, + 'tool-result-hook': StreamEvent_ToolResultHook, + 'tool-update': StreamEvent_ToolUpdate, + 'model-update': StreamEvent_ModelUpdate, + 'agent-result': StreamEvent_AgentResult, +} + +def _StreamEvent_lift(raw: _WitVariant) -> StreamEvent: + cls = _StreamEvent_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown StreamEvent arm: {raw.tag!r}') + return cls(raw.payload) +StreamEvent.lift = staticmethod(_StreamEvent_lift) # type: ignore[attr-defined] + +class ModelEventStream: + """Pull-based stream of model events from a custom provider; host produces, guest reads.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def read(self) -> Optional[StreamEvent]: + return self._invoke('[method]model-event-stream.read', (self._handle,)) + + +class ModelStreamOptions: + """Options passed alongside the messages on each streaming call.""" + def __init__( + self, + *, + system_prompt: Optional[PromptInput], + tools: Optional[list[ToolSpec]], + tool_choice: Optional[ToolChoice], + ) -> None: + setattr(self, 'system-prompt', system_prompt) + setattr(self, 'tools', tools) + setattr(self, 'tool-choice', tool_choice) + + @property + def system_prompt(self) -> Optional[PromptInput]: + return getattr(self, 'system-prompt') + + @property + def tool_choice(self) -> Optional[ToolChoice]: + return getattr(self, 'tool-choice') + + def __repr__(self) -> str: + return f'ModelStreamOptions(system_prompt={getattr(self, 'system-prompt')!r}, tools={getattr(self, 'tools')!r}, tool_choice={getattr(self, 'tool-choice')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ModelStreamOptions): + return NotImplemented + return getattr(self, 'system-prompt') == getattr(other, 'system-prompt') and getattr(self, 'tools') == getattr(other, 'tools') and getattr(self, 'tool-choice') == getattr(other, 'tool-choice') + + def __hash__(self) -> int: + return id(self) + + +class StartStreamArgs: + """Arguments for `start-stream`.""" + def __init__( + self, + *, + provider_id: str, + messages: list[Message], + options: ModelStreamOptions, + ) -> None: + setattr(self, 'provider-id', provider_id) + setattr(self, 'messages', messages) + setattr(self, 'options', options) + + @property + def provider_id(self) -> str: + return getattr(self, 'provider-id') + + def __repr__(self) -> str: + return f'StartStreamArgs(provider_id={getattr(self, 'provider-id')!r}, messages={getattr(self, 'messages')!r}, options={getattr(self, 'options')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, StartStreamArgs): + return NotImplemented + return getattr(self, 'provider-id') == getattr(other, 'provider-id') and getattr(self, 'messages') == getattr(other, 'messages') and getattr(self, 'options') == getattr(other, 'options') + + def __hash__(self) -> int: + return id(self) + + +class CountTokensArgs: + """Arguments for `count-tokens`.""" + def __init__( + self, + *, + provider_id: str, + messages: list[Message], + system_prompt: Optional[PromptInput], + tools: Optional[list[ToolSpec]], + ) -> None: + setattr(self, 'provider-id', provider_id) + setattr(self, 'messages', messages) + setattr(self, 'system-prompt', system_prompt) + setattr(self, 'tools', tools) + + @property + def provider_id(self) -> str: + return getattr(self, 'provider-id') + + @property + def system_prompt(self) -> Optional[PromptInput]: + return getattr(self, 'system-prompt') + + def __repr__(self) -> str: + return f'CountTokensArgs(provider_id={getattr(self, 'provider-id')!r}, messages={getattr(self, 'messages')!r}, system_prompt={getattr(self, 'system-prompt')!r}, tools={getattr(self, 'tools')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CountTokensArgs): + return NotImplemented + return getattr(self, 'provider-id') == getattr(other, 'provider-id') and getattr(self, 'messages') == getattr(other, 'messages') and getattr(self, 'system-prompt') == getattr(other, 'system-prompt') and getattr(self, 'tools') == getattr(other, 'tools') + + def __hash__(self) -> int: + return id(self) + + +class OrchestrationStatus(str): + """Lifecycle status of a node or overall run.""" + __slots__ = () + + PENDING: 'OrchestrationStatus' + EXECUTING: 'OrchestrationStatus' + COMPLETED: 'OrchestrationStatus' + FAILED: 'OrchestrationStatus' + CANCELLED: 'OrchestrationStatus' + +OrchestrationStatus.PENDING = OrchestrationStatus('pending') # type: ignore[attr-defined] +OrchestrationStatus.EXECUTING = OrchestrationStatus('executing') # type: ignore[attr-defined] +OrchestrationStatus.COMPLETED = OrchestrationStatus('completed') # type: ignore[attr-defined] +OrchestrationStatus.FAILED = OrchestrationStatus('failed') # type: ignore[attr-defined] +OrchestrationStatus.CANCELLED = OrchestrationStatus('cancelled') # type: ignore[attr-defined] + + +class TerminalStatus(str): + """Terminal status of a node or run.""" + __slots__ = () + + COMPLETED: 'TerminalStatus' + FAILED: 'TerminalStatus' + CANCELLED: 'TerminalStatus' + +TerminalStatus.COMPLETED = TerminalStatus('completed') # type: ignore[attr-defined] +TerminalStatus.FAILED = TerminalStatus('failed') # type: ignore[attr-defined] +TerminalStatus.CANCELLED = TerminalStatus('cancelled') # type: ignore[attr-defined] + + +class NodeKind(str): + """What a node is.""" + __slots__ = () + + AGENT: 'NodeKind' + MULTI_AGENT: 'NodeKind' + +NodeKind.AGENT = NodeKind('agent') # type: ignore[attr-defined] +NodeKind.MULTI_AGENT = NodeKind('multi-agent') # type: ignore[attr-defined] + + +class AgentNodeConfig: + """Definition of an agent-backed node.""" + def __init__( + self, + *, + id: str, + description: Optional[str], + timeout: Optional[int], + agent_config: str, + ) -> None: + setattr(self, 'id', id) + setattr(self, 'description', description) + setattr(self, 'timeout', timeout) + setattr(self, 'agent-config', agent_config) + + @property + def agent_config(self) -> str: + return getattr(self, 'agent-config') + + def __repr__(self) -> str: + return f'AgentNodeConfig(id={getattr(self, 'id')!r}, description={getattr(self, 'description')!r}, timeout={getattr(self, 'timeout')!r}, agent_config={getattr(self, 'agent-config')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AgentNodeConfig): + return NotImplemented + return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'description') == getattr(other, 'description') and getattr(self, 'timeout') == getattr(other, 'timeout') and getattr(self, 'agent-config') == getattr(other, 'agent-config') + + def __hash__(self) -> int: + return id(self) + + +class MultiAgentNodeConfig: + """Definition of a node that wraps another orchestrator.""" + def __init__( + self, + *, + id: str, + description: Optional[str], + orchestrator: str, + ) -> None: + setattr(self, 'id', id) + setattr(self, 'description', description) + setattr(self, 'orchestrator', orchestrator) + + def __repr__(self) -> str: + return f'MultiAgentNodeConfig(id={getattr(self, 'id')!r}, description={getattr(self, 'description')!r}, orchestrator={getattr(self, 'orchestrator')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MultiAgentNodeConfig): + return NotImplemented + return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'description') == getattr(other, 'description') and getattr(self, 'orchestrator') == getattr(other, 'orchestrator') + + def __hash__(self) -> int: + return id(self) + + +class NodeConfig: + """Any node a graph or swarm can execute.""" + pass + +class NodeConfig_Agent(NodeConfig, _WitVariantCase): + """Wraps a single agent.""" + tag = 'agent' + +class NodeConfig_MultiAgent(NodeConfig, _WitVariantCase): + """Wraps a nested orchestrator.""" + tag = 'multi-agent' + +_NodeConfig_CASES: dict[str, type] = { + 'agent': NodeConfig_Agent, + 'multi-agent': NodeConfig_MultiAgent, +} + +def _NodeConfig_lift(raw: _WitVariant) -> NodeConfig: + cls = _NodeConfig_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown NodeConfig arm: {raw.tag!r}') + return cls(raw.payload) +NodeConfig.lift = staticmethod(_NodeConfig_lift) # type: ignore[attr-defined] + +class EdgeHandler: + """Condition attached to a graph edge.""" + def __init__( + self, + *, + handler_id: str, + ) -> None: + setattr(self, 'handler-id', handler_id) + + @property + def handler_id(self) -> str: + return getattr(self, 'handler-id') + + def __repr__(self) -> str: + return f'EdgeHandler(handler_id={getattr(self, 'handler-id')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EdgeHandler): + return NotImplemented + return getattr(self, 'handler-id') == getattr(other, 'handler-id') + + def __hash__(self) -> int: + return id(self) + + +class EdgeConfig: + """Edge connecting two graph nodes.""" + def __init__( + self, + *, + source: str, + target: str, + handler: Optional[EdgeHandler], + ) -> None: + setattr(self, 'source', source) + setattr(self, 'target', target) + setattr(self, 'handler', handler) + + def __repr__(self) -> str: + return f'EdgeConfig(source={getattr(self, 'source')!r}, target={getattr(self, 'target')!r}, handler={getattr(self, 'handler')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EdgeConfig): + return NotImplemented + return getattr(self, 'source') == getattr(other, 'source') and getattr(self, 'target') == getattr(other, 'target') and getattr(self, 'handler') == getattr(other, 'handler') + + def __hash__(self) -> int: + return id(self) + + +class GraphConfig: + """Runtime configuration for a Graph.""" + def __init__( + self, + *, + id: str, + nodes: list[NodeConfig], + edges: list[EdgeConfig], + sources: list[str], + max_concurrency: Optional[int], + max_steps: Optional[int], + timeout: Optional[int], + node_timeout: Optional[int], + ) -> None: + setattr(self, 'id', id) + setattr(self, 'nodes', nodes) + setattr(self, 'edges', edges) + setattr(self, 'sources', sources) + setattr(self, 'max-concurrency', max_concurrency) + setattr(self, 'max-steps', max_steps) + setattr(self, 'timeout', timeout) + setattr(self, 'node-timeout', node_timeout) + + @property + def max_concurrency(self) -> Optional[int]: + return getattr(self, 'max-concurrency') + + @property + def max_steps(self) -> Optional[int]: + return getattr(self, 'max-steps') + + @property + def node_timeout(self) -> Optional[int]: + return getattr(self, 'node-timeout') + + def __repr__(self) -> str: + return f'GraphConfig(id={getattr(self, 'id')!r}, nodes={getattr(self, 'nodes')!r}, edges={getattr(self, 'edges')!r}, sources={getattr(self, 'sources')!r}, max_concurrency={getattr(self, 'max-concurrency')!r}, max_steps={getattr(self, 'max-steps')!r}, timeout={getattr(self, 'timeout')!r}, node_timeout={getattr(self, 'node-timeout')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, GraphConfig): + return NotImplemented + return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'nodes') == getattr(other, 'nodes') and getattr(self, 'edges') == getattr(other, 'edges') and getattr(self, 'sources') == getattr(other, 'sources') and getattr(self, 'max-concurrency') == getattr(other, 'max-concurrency') and getattr(self, 'max-steps') == getattr(other, 'max-steps') and getattr(self, 'timeout') == getattr(other, 'timeout') and getattr(self, 'node-timeout') == getattr(other, 'node-timeout') + + def __hash__(self) -> int: + return id(self) + + +class SwarmConfig: + """Runtime configuration for a Swarm.""" + def __init__( + self, + *, + id: str, + nodes: list[AgentNodeConfig], + start_node_id: str, + max_steps: Optional[int], + timeout: Optional[int], + node_timeout: Optional[int], + ) -> None: + setattr(self, 'id', id) + setattr(self, 'nodes', nodes) + setattr(self, 'start-node-id', start_node_id) + setattr(self, 'max-steps', max_steps) + setattr(self, 'timeout', timeout) + setattr(self, 'node-timeout', node_timeout) + + @property + def start_node_id(self) -> str: + return getattr(self, 'start-node-id') + + @property + def max_steps(self) -> Optional[int]: + return getattr(self, 'max-steps') + + @property + def node_timeout(self) -> Optional[int]: + return getattr(self, 'node-timeout') + + def __repr__(self) -> str: + return f'SwarmConfig(id={getattr(self, 'id')!r}, nodes={getattr(self, 'nodes')!r}, start_node_id={getattr(self, 'start-node-id')!r}, max_steps={getattr(self, 'max-steps')!r}, timeout={getattr(self, 'timeout')!r}, node_timeout={getattr(self, 'node-timeout')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SwarmConfig): + return NotImplemented + return getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'nodes') == getattr(other, 'nodes') and getattr(self, 'start-node-id') == getattr(other, 'start-node-id') and getattr(self, 'max-steps') == getattr(other, 'max-steps') and getattr(self, 'timeout') == getattr(other, 'timeout') and getattr(self, 'node-timeout') == getattr(other, 'node-timeout') + + def __hash__(self) -> int: + return id(self) + + +class NodeError: + """Why a node or run ended in `failed` status.""" + pass + +class NodeError_Execution(NodeError, _WitVariantCase): + """An underlying agent or nested orchestrator failed.""" + tag = 'execution' + +class NodeError_Timeout(NodeError, _WitVariantCase): + """Wall-clock ceiling was exceeded.""" + tag = 'timeout' + +class NodeError_LimitExceeded(NodeError, _WitVariantCase): + """A declared runtime limit (max-steps, max-concurrency) was hit.""" + tag = 'limit-exceeded' + +class NodeError_EdgeHandler(NodeError, _WitVariantCase): + """Edge handler rejected the traversal with an error.""" + tag = 'edge-handler' + +class NodeError_InvalidConfig(NodeError, _WitVariantCase): + """Invalid configuration detected at run time.""" + tag = 'invalid-config' + +class NodeError_Internal(NodeError, _WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + +_NodeError_CASES: dict[str, type] = { + 'execution': NodeError_Execution, + 'timeout': NodeError_Timeout, + 'limit-exceeded': NodeError_LimitExceeded, + 'edge-handler': NodeError_EdgeHandler, + 'invalid-config': NodeError_InvalidConfig, + 'internal': NodeError_Internal, +} + +def _NodeError_lift(raw: _WitVariant) -> NodeError: + cls = _NodeError_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown NodeError arm: {raw.tag!r}') + return cls(raw.payload) +NodeError.lift = staticmethod(_NodeError_lift) # type: ignore[attr-defined] + +class NodeResult: + """Result of a single node execution.""" + def __init__( + self, + *, + node_id: str, + status: TerminalStatus, + duration: int, + content: list[ContentBlock], + error: Optional[NodeError], + structured_output: Optional[str], + usage: Optional[Usage], + metrics: Optional[Metrics], + ) -> None: + setattr(self, 'node-id', node_id) + setattr(self, 'status', status) + setattr(self, 'duration', duration) + setattr(self, 'content', content) + setattr(self, 'error', error) + setattr(self, 'structured-output', structured_output) + setattr(self, 'usage', usage) + setattr(self, 'metrics', metrics) + + @property + def node_id(self) -> str: + return getattr(self, 'node-id') + + @property + def structured_output(self) -> Optional[str]: + return getattr(self, 'structured-output') + + def __repr__(self) -> str: + return f'NodeResult(node_id={getattr(self, 'node-id')!r}, status={getattr(self, 'status')!r}, duration={getattr(self, 'duration')!r}, content={getattr(self, 'content')!r}, error={getattr(self, 'error')!r}, structured_output={getattr(self, 'structured-output')!r}, usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NodeResult): + return NotImplemented + return getattr(self, 'node-id') == getattr(other, 'node-id') and getattr(self, 'status') == getattr(other, 'status') and getattr(self, 'duration') == getattr(other, 'duration') and getattr(self, 'content') == getattr(other, 'content') and getattr(self, 'error') == getattr(other, 'error') and getattr(self, 'structured-output') == getattr(other, 'structured-output') and getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') + + def __hash__(self) -> int: + return id(self) + + +class MultiAgentResult: + """Final result of a graph or swarm run.""" + def __init__( + self, + *, + status: TerminalStatus, + nodes: list[NodeResult], + duration: int, + usage: Optional[Usage], + metrics: Optional[Metrics], + ) -> None: + setattr(self, 'status', status) + setattr(self, 'nodes', nodes) + setattr(self, 'duration', duration) + setattr(self, 'usage', usage) + setattr(self, 'metrics', metrics) + + def __repr__(self) -> str: + return f'MultiAgentResult(status={getattr(self, 'status')!r}, nodes={getattr(self, 'nodes')!r}, duration={getattr(self, 'duration')!r}, usage={getattr(self, 'usage')!r}, metrics={getattr(self, 'metrics')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MultiAgentResult): + return NotImplemented + return getattr(self, 'status') == getattr(other, 'status') and getattr(self, 'nodes') == getattr(other, 'nodes') and getattr(self, 'duration') == getattr(other, 'duration') and getattr(self, 'usage') == getattr(other, 'usage') and getattr(self, 'metrics') == getattr(other, 'metrics') + + def __hash__(self) -> int: + return id(self) + + +class MultiAgentInvokeArgs: + """Arguments for invoking a graph or swarm.""" + def __init__( + self, + *, + input: PromptInput, + invocation_state: Optional[str], + ) -> None: + setattr(self, 'input', input) + setattr(self, 'invocation-state', invocation_state) + + @property + def invocation_state(self) -> Optional[str]: + return getattr(self, 'invocation-state') + + def __repr__(self) -> str: + return f'MultiAgentInvokeArgs(input={getattr(self, 'input')!r}, invocation_state={getattr(self, 'invocation-state')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MultiAgentInvokeArgs): + return NotImplemented + return getattr(self, 'input') == getattr(other, 'input') and getattr(self, 'invocation-state') == getattr(other, 'invocation-state') + + def __hash__(self) -> int: + return id(self) + + +class NodeStartData: + """Payload for `node-start`.""" + def __init__( + self, + *, + node_id: str, + kind: NodeKind, + ) -> None: + setattr(self, 'node-id', node_id) + setattr(self, 'kind', kind) + + @property + def node_id(self) -> str: + return getattr(self, 'node-id') + + def __repr__(self) -> str: + return f'NodeStartData(node_id={getattr(self, 'node-id')!r}, kind={getattr(self, 'kind')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NodeStartData): + return NotImplemented + return getattr(self, 'node-id') == getattr(other, 'node-id') and getattr(self, 'kind') == getattr(other, 'kind') + + def __hash__(self) -> int: + return id(self) + + +class NodeEventData: + """Payload for `node-event`. Carries a nested stream event from a +running node.""" + def __init__( + self, + *, + node_id: str, + event: StreamEvent, + ) -> None: + setattr(self, 'node-id', node_id) + setattr(self, 'event', event) + + @property + def node_id(self) -> str: + return getattr(self, 'node-id') + + def __repr__(self) -> str: + return f'NodeEventData(node_id={getattr(self, 'node-id')!r}, event={getattr(self, 'event')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NodeEventData): + return NotImplemented + return getattr(self, 'node-id') == getattr(other, 'node-id') and getattr(self, 'event') == getattr(other, 'event') + + def __hash__(self) -> int: + return id(self) + + +class HandoffEvent: + """Payload for a handoff edge firing.""" + def __init__( + self, + *, + from_node_ids: list[str], + to_node_ids: list[str], + ) -> None: + setattr(self, 'from-node-ids', from_node_ids) + setattr(self, 'to-node-ids', to_node_ids) + + @property + def from_node_ids(self) -> list[str]: + return getattr(self, 'from-node-ids') + + @property + def to_node_ids(self) -> list[str]: + return getattr(self, 'to-node-ids') + + def __repr__(self) -> str: + return f'HandoffEvent(from_node_ids={getattr(self, 'from-node-ids')!r}, to_node_ids={getattr(self, 'to-node-ids')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, HandoffEvent): + return NotImplemented + return getattr(self, 'from-node-ids') == getattr(other, 'from-node-ids') and getattr(self, 'to-node-ids') == getattr(other, 'to-node-ids') + + def __hash__(self) -> int: + return id(self) + + +class MultiAgentStreamEvent: + """Events emitted while streaming a multi-agent run.""" + pass + +class MultiAgentStreamEvent_NodeStart(MultiAgentStreamEvent, _WitVariantCase): + """A node began executing.""" + tag = 'node-start' + +class MultiAgentStreamEvent_Nested(MultiAgentStreamEvent, _WitVariantCase): + """A nested stream event from a running node.""" + tag = 'nested' + +class MultiAgentStreamEvent_NodeStop(MultiAgentStreamEvent, _WitVariantCase): + """A node finished executing.""" + tag = 'node-stop' + +class MultiAgentStreamEvent_Handoff(MultiAgentStreamEvent, _WitVariantCase): + """A handoff happened between nodes.""" + tag = 'handoff' + +class MultiAgentStreamEvent_RunComplete(MultiAgentStreamEvent, _WitVariantCase): + """Terminal result for the run.""" + tag = 'run-complete' + +_MultiAgentStreamEvent_CASES: dict[str, type] = { + 'node-start': MultiAgentStreamEvent_NodeStart, + 'nested': MultiAgentStreamEvent_Nested, + 'node-stop': MultiAgentStreamEvent_NodeStop, + 'handoff': MultiAgentStreamEvent_Handoff, + 'run-complete': MultiAgentStreamEvent_RunComplete, +} + +def _MultiAgentStreamEvent_lift(raw: _WitVariant) -> MultiAgentStreamEvent: + cls = _MultiAgentStreamEvent_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown MultiAgentStreamEvent arm: {raw.tag!r}') + return cls(raw.payload) +MultiAgentStreamEvent.lift = staticmethod(_MultiAgentStreamEvent_lift) # type: ignore[attr-defined] + +class EdgeHandlerError: + """Why an edge evaluation failed.""" + pass + +class EdgeHandlerError_Unknown(EdgeHandlerError, _WitVariantCase): + """No handler registered for the given id.""" + tag = 'unknown' + +class EdgeHandlerError_Failed(EdgeHandlerError, _WitVariantCase): + """Handler raised an exception.""" + tag = 'failed' + +_EdgeHandlerError_CASES: dict[str, type] = { + 'unknown': EdgeHandlerError_Unknown, + 'failed': EdgeHandlerError_Failed, +} + +def _EdgeHandlerError_lift(raw: _WitVariant) -> EdgeHandlerError: + cls = _EdgeHandlerError_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown EdgeHandlerError arm: {raw.tag!r}') + return cls(raw.payload) +EdgeHandlerError.lift = staticmethod(_EdgeHandlerError_lift) # type: ignore[attr-defined] + +class HandlerState: + """State snapshot passed to `evaluate` so the handler can branch on +prior node results.""" + def __init__( + self, + *, + results: list[NodeResult], + execution_count: int, + ) -> None: + setattr(self, 'results', results) + setattr(self, 'execution-count', execution_count) + + @property + def execution_count(self) -> int: + return getattr(self, 'execution-count') + + def __repr__(self) -> str: + return f'HandlerState(results={getattr(self, 'results')!r}, execution_count={getattr(self, 'execution-count')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, HandlerState): + return NotImplemented + return getattr(self, 'results') == getattr(other, 'results') and getattr(self, 'execution-count') == getattr(other, 'execution-count') + + def __hash__(self) -> int: + return id(self) + + +class BashToolConfig: + """Bash tool configuration.""" + def __init__( + self, + *, + default_timeout_s: Optional[int], + ) -> None: + setattr(self, 'default-timeout-s', default_timeout_s) + + @property + def default_timeout_s(self) -> Optional[int]: + return getattr(self, 'default-timeout-s') + + def __repr__(self) -> str: + return f'BashToolConfig(default_timeout_s={getattr(self, 'default-timeout-s')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, BashToolConfig): + return NotImplemented + return getattr(self, 'default-timeout-s') == getattr(other, 'default-timeout-s') + + def __hash__(self) -> int: + return id(self) + + +class FileEditorToolConfig: + """File editor tool configuration.""" + def __init__( + self, + *, + workspace_root: Optional[str], + ) -> None: + setattr(self, 'workspace-root', workspace_root) + + @property + def workspace_root(self) -> Optional[str]: + return getattr(self, 'workspace-root') + + def __repr__(self) -> str: + return f'FileEditorToolConfig(workspace_root={getattr(self, 'workspace-root')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, FileEditorToolConfig): + return NotImplemented + return getattr(self, 'workspace-root') == getattr(other, 'workspace-root') + + def __hash__(self) -> int: + return id(self) + + +class HttpRequestToolConfig: + """HTTP request tool configuration.""" + def __init__( + self, + *, + allowed_hosts: list[str], + max_response_bytes: int, + ) -> None: + setattr(self, 'allowed-hosts', allowed_hosts) + setattr(self, 'max-response-bytes', max_response_bytes) + + @property + def allowed_hosts(self) -> list[str]: + return getattr(self, 'allowed-hosts') + + @property + def max_response_bytes(self) -> int: + return getattr(self, 'max-response-bytes') + + def __repr__(self) -> str: + return f'HttpRequestToolConfig(allowed_hosts={getattr(self, 'allowed-hosts')!r}, max_response_bytes={getattr(self, 'max-response-bytes')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, HttpRequestToolConfig): + return NotImplemented + return getattr(self, 'allowed-hosts') == getattr(other, 'allowed-hosts') and getattr(self, 'max-response-bytes') == getattr(other, 'max-response-bytes') + + def __hash__(self) -> int: + return id(self) + + +class NotebookToolConfig: + """Notebook tool configuration.""" + def __init__( + self, + *, + workspace_root: Optional[str], + ) -> None: + setattr(self, 'workspace-root', workspace_root) + + @property + def workspace_root(self) -> Optional[str]: + return getattr(self, 'workspace-root') + + def __repr__(self) -> str: + return f'NotebookToolConfig(workspace_root={getattr(self, 'workspace-root')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NotebookToolConfig): + return NotImplemented + return getattr(self, 'workspace-root') == getattr(other, 'workspace-root') + + def __hash__(self) -> int: + return id(self) + + +class VendedTool: + """Built-in tools.""" + pass + +class VendedTool_Bash(VendedTool, _WitVariantCase): + """Run shell commands in a persistent bash session.""" + tag = 'bash' + +class VendedTool_FileEditor(VendedTool, _WitVariantCase): + """Create, view, and edit files on disk.""" + tag = 'file-editor' + +class VendedTool_HttpRequest(VendedTool, _WitVariantCase): + """Make HTTP requests.""" + tag = 'http-request' + +class VendedTool_Notebook(VendedTool, _WitVariantCase): + """Read and execute Jupyter notebook cells.""" + tag = 'notebook' + +_VendedTool_CASES: dict[str, type] = { + 'bash': VendedTool_Bash, + 'file-editor': VendedTool_FileEditor, + 'http-request': VendedTool_HttpRequest, + 'notebook': VendedTool_Notebook, +} + +def _VendedTool_lift(raw: _WitVariant) -> VendedTool: + cls = _VendedTool_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown VendedTool arm: {raw.tag!r}') + return cls(raw.payload) +VendedTool.lift = staticmethod(_VendedTool_lift) # type: ignore[attr-defined] + +class SkillSource: + """Location of a skill definition on disk.""" + def __init__( + self, + *, + path: str, + ) -> None: + setattr(self, 'path', path) + + def __repr__(self) -> str: + return f'SkillSource(path={getattr(self, 'path')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SkillSource): + return NotImplemented + return getattr(self, 'path') == getattr(other, 'path') + + def __hash__(self) -> int: + return id(self) + + +class SkillsPluginConfig: + """Skills plugin configuration.""" + def __init__( + self, + *, + skills: list[SkillSource], + strict: bool, + max_resource_files: Optional[int], + state_key: Optional[str], + ) -> None: + setattr(self, 'skills', skills) + setattr(self, 'strict', strict) + setattr(self, 'max-resource-files', max_resource_files) + setattr(self, 'state-key', state_key) + + @property + def max_resource_files(self) -> Optional[int]: + return getattr(self, 'max-resource-files') + + @property + def state_key(self) -> Optional[str]: + return getattr(self, 'state-key') + + def __repr__(self) -> str: + return f'SkillsPluginConfig(skills={getattr(self, 'skills')!r}, strict={getattr(self, 'strict')!r}, max_resource_files={getattr(self, 'max-resource-files')!r}, state_key={getattr(self, 'state-key')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SkillsPluginConfig): + return NotImplemented + return getattr(self, 'skills') == getattr(other, 'skills') and getattr(self, 'strict') == getattr(other, 'strict') and getattr(self, 'max-resource-files') == getattr(other, 'max-resource-files') and getattr(self, 'state-key') == getattr(other, 'state-key') + + def __hash__(self) -> int: + return id(self) + + +class ContextOffloaderPluginConfig: + """Context offloader plugin configuration.""" + def __init__( + self, + *, + max_result_tokens: Optional[int], + preview_tokens: Optional[int], + include_retrieval_tool: bool, + ) -> None: + setattr(self, 'max-result-tokens', max_result_tokens) + setattr(self, 'preview-tokens', preview_tokens) + setattr(self, 'include-retrieval-tool', include_retrieval_tool) + + @property + def max_result_tokens(self) -> Optional[int]: + return getattr(self, 'max-result-tokens') + + @property + def preview_tokens(self) -> Optional[int]: + return getattr(self, 'preview-tokens') + + @property + def include_retrieval_tool(self) -> bool: + return getattr(self, 'include-retrieval-tool') + + def __repr__(self) -> str: + return f'ContextOffloaderPluginConfig(max_result_tokens={getattr(self, 'max-result-tokens')!r}, preview_tokens={getattr(self, 'preview-tokens')!r}, include_retrieval_tool={getattr(self, 'include-retrieval-tool')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ContextOffloaderPluginConfig): + return NotImplemented + return getattr(self, 'max-result-tokens') == getattr(other, 'max-result-tokens') and getattr(self, 'preview-tokens') == getattr(other, 'preview-tokens') and getattr(self, 'include-retrieval-tool') == getattr(other, 'include-retrieval-tool') + + def __hash__(self) -> int: + return id(self) + + +class VendedPlugin: + """Built-in plugins.""" + pass + +class VendedPlugin_Skills(VendedPlugin, _WitVariantCase): + """Load and activate Anthropic-style skills from disk.""" + tag = 'skills' + +class VendedPlugin_ContextOffloader(VendedPlugin, _WitVariantCase): + """Offload large tool results to external storage.""" + tag = 'context-offloader' + +_VendedPlugin_CASES: dict[str, type] = { + 'skills': VendedPlugin_Skills, + 'context-offloader': VendedPlugin_ContextOffloader, +} + +def _VendedPlugin_lift(raw: _WitVariant) -> VendedPlugin: + cls = _VendedPlugin_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown VendedPlugin arm: {raw.tag!r}') + return cls(raw.payload) +VendedPlugin.lift = staticmethod(_VendedPlugin_lift) # type: ignore[attr-defined] + +class ConcurrentOptions: + """Concurrent-execution options.""" + def __init__( + self, + *, + max_concurrency: Optional[int], + ) -> None: + setattr(self, 'max-concurrency', max_concurrency) + + @property + def max_concurrency(self) -> Optional[int]: + return getattr(self, 'max-concurrency') + + def __repr__(self) -> str: + return f'ConcurrentOptions(max_concurrency={getattr(self, 'max-concurrency')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ConcurrentOptions): + return NotImplemented + return getattr(self, 'max-concurrency') == getattr(other, 'max-concurrency') + + def __hash__(self) -> int: + return id(self) + + +ToolExecutorStrategy = None | ConcurrentOptions +"""Strategy for executing tool calls emitted in a single assistant turn.""" + +AttributeValue = str | int | float | bool +"""Scalar attribute value attached to a trace.""" + +class TraceAttribute: + """Key-value pair attached to every OpenTelemetry span the agent emits. +Distinct from `streaming.trace-metadata-entry`, which is string-only.""" + def __init__( + self, + *, + key: str, + value: AttributeValue, + ) -> None: + setattr(self, 'key', key) + setattr(self, 'value', value) + + def __repr__(self) -> str: + return f'TraceAttribute(key={getattr(self, 'key')!r}, value={getattr(self, 'value')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TraceAttribute): + return NotImplemented + return getattr(self, 'key') == getattr(other, 'key') and getattr(self, 'value') == getattr(other, 'value') + + def __hash__(self) -> int: + return id(self) + + +class TraceContext: + """W3C Trace Context headers linking the agent's spans to a caller's trace.""" + def __init__( + self, + *, + traceparent: str, + tracestate: Optional[str], + ) -> None: + setattr(self, 'traceparent', traceparent) + setattr(self, 'tracestate', tracestate) + + def __repr__(self) -> str: + return f'TraceContext(traceparent={getattr(self, 'traceparent')!r}, tracestate={getattr(self, 'tracestate')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TraceContext): + return NotImplemented + return getattr(self, 'traceparent') == getattr(other, 'traceparent') and getattr(self, 'tracestate') == getattr(other, 'tracestate') + + def __hash__(self) -> int: + return id(self) + + +class AgentIdentity: + """Display-level identity of the agent; all fields default to sensible values.""" + def __init__( + self, + *, + name: Optional[str], + id: Optional[str], + description: Optional[str], + ) -> None: + setattr(self, 'name', name) + setattr(self, 'id', id) + setattr(self, 'description', description) + + def __repr__(self) -> str: + return f'AgentIdentity(name={getattr(self, 'name')!r}, id={getattr(self, 'id')!r}, description={getattr(self, 'description')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AgentIdentity): + return NotImplemented + return getattr(self, 'name') == getattr(other, 'name') and getattr(self, 'id') == getattr(other, 'id') and getattr(self, 'description') == getattr(other, 'description') + + def __hash__(self) -> int: + return id(self) + + +class AgentConfig: + """Configuration passed to the `agent` constructor. +Invalid config surfaces on the first `generate` as `invalid-input`.""" + def __init__( + self, + *, + model: Optional[ModelConfig], + model_params: Optional[ModelParams], + messages: Optional[list[Message]], + system_prompt: Optional[PromptInput], + tools: Optional[list[ToolSpec]], + agent_tools: Optional[list[AgentAsToolConfig]], + vended_tools: Optional[list[VendedTool]], + vended_plugins: Optional[list[VendedPlugin]], + mcp_clients: Optional[list[McpClientConfig]], + identity: Optional[AgentIdentity], + tool_executor: Optional[ToolExecutorStrategy], + display_output: Optional[bool], + trace_attributes: Optional[list[TraceAttribute]], + trace_context: Optional[TraceContext], + session: Optional[SessionConfig], + conversation_manager: Optional[ConversationManagerConfig], + retry: Optional[RetryConfig], + structured_output_schema: Optional[str], + app_state: Optional[str], + model_state: Optional[str], + ) -> None: + setattr(self, 'model', model) + setattr(self, 'model-params', model_params) + setattr(self, 'messages', messages) + setattr(self, 'system-prompt', system_prompt) + setattr(self, 'tools', tools) + setattr(self, 'agent-tools', agent_tools) + setattr(self, 'vended-tools', vended_tools) + setattr(self, 'vended-plugins', vended_plugins) + setattr(self, 'mcp-clients', mcp_clients) + setattr(self, 'identity', identity) + setattr(self, 'tool-executor', (_WitVariant('none') if tool_executor is None else _WitVariant('some', tool_executor))) + setattr(self, 'display-output', display_output) + setattr(self, 'trace-attributes', trace_attributes) + setattr(self, 'trace-context', trace_context) + setattr(self, 'session', session) + setattr(self, 'conversation-manager', conversation_manager) + setattr(self, 'retry', retry) + setattr(self, 'structured-output-schema', structured_output_schema) + setattr(self, 'app-state', app_state) + setattr(self, 'model-state', model_state) + + @property + def model_params(self) -> Optional[ModelParams]: + return getattr(self, 'model-params') + + @property + def system_prompt(self) -> Optional[PromptInput]: + return getattr(self, 'system-prompt') + + @property + def agent_tools(self) -> Optional[list[AgentAsToolConfig]]: + return getattr(self, 'agent-tools') + + @property + def vended_tools(self) -> Optional[list[VendedTool]]: + return getattr(self, 'vended-tools') + + @property + def vended_plugins(self) -> Optional[list[VendedPlugin]]: + return getattr(self, 'vended-plugins') + + @property + def mcp_clients(self) -> Optional[list[McpClientConfig]]: + return getattr(self, 'mcp-clients') + + @property + def tool_executor(self) -> Optional[ToolExecutorStrategy]: + return (None if getattr(self, 'tool-executor').tag == 'none' else getattr(self, 'tool-executor').payload) + + @property + def display_output(self) -> Optional[bool]: + return getattr(self, 'display-output') + + @property + def trace_attributes(self) -> Optional[list[TraceAttribute]]: + return getattr(self, 'trace-attributes') + + @property + def trace_context(self) -> Optional[TraceContext]: + return getattr(self, 'trace-context') + + @property + def conversation_manager(self) -> Optional[ConversationManagerConfig]: + return getattr(self, 'conversation-manager') + + @property + def structured_output_schema(self) -> Optional[str]: + return getattr(self, 'structured-output-schema') + + @property + def app_state(self) -> Optional[str]: + return getattr(self, 'app-state') + + @property + def model_state(self) -> Optional[str]: + return getattr(self, 'model-state') + + def __repr__(self) -> str: + return f'AgentConfig(model={getattr(self, 'model')!r}, model_params={getattr(self, 'model-params')!r}, messages={getattr(self, 'messages')!r}, system_prompt={getattr(self, 'system-prompt')!r}, tools={getattr(self, 'tools')!r}, agent_tools={getattr(self, 'agent-tools')!r}, vended_tools={getattr(self, 'vended-tools')!r}, vended_plugins={getattr(self, 'vended-plugins')!r}, mcp_clients={getattr(self, 'mcp-clients')!r}, identity={getattr(self, 'identity')!r}, tool_executor={getattr(self, 'tool-executor')!r}, display_output={getattr(self, 'display-output')!r}, trace_attributes={getattr(self, 'trace-attributes')!r}, trace_context={getattr(self, 'trace-context')!r}, session={getattr(self, 'session')!r}, conversation_manager={getattr(self, 'conversation-manager')!r}, retry={getattr(self, 'retry')!r}, structured_output_schema={getattr(self, 'structured-output-schema')!r}, app_state={getattr(self, 'app-state')!r}, model_state={getattr(self, 'model-state')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, AgentConfig): + return NotImplemented + return getattr(self, 'model') == getattr(other, 'model') and getattr(self, 'model-params') == getattr(other, 'model-params') and getattr(self, 'messages') == getattr(other, 'messages') and getattr(self, 'system-prompt') == getattr(other, 'system-prompt') and getattr(self, 'tools') == getattr(other, 'tools') and getattr(self, 'agent-tools') == getattr(other, 'agent-tools') and getattr(self, 'vended-tools') == getattr(other, 'vended-tools') and getattr(self, 'vended-plugins') == getattr(other, 'vended-plugins') and getattr(self, 'mcp-clients') == getattr(other, 'mcp-clients') and getattr(self, 'identity') == getattr(other, 'identity') and getattr(self, 'tool-executor') == getattr(other, 'tool-executor') and getattr(self, 'display-output') == getattr(other, 'display-output') and getattr(self, 'trace-attributes') == getattr(other, 'trace-attributes') and getattr(self, 'trace-context') == getattr(other, 'trace-context') and getattr(self, 'session') == getattr(other, 'session') and getattr(self, 'conversation-manager') == getattr(other, 'conversation-manager') and getattr(self, 'retry') == getattr(other, 'retry') and getattr(self, 'structured-output-schema') == getattr(other, 'structured-output-schema') and getattr(self, 'app-state') == getattr(other, 'app-state') and getattr(self, 'model-state') == getattr(other, 'model-state') + + def __hash__(self) -> int: + return id(self) + + +class InvokeArgs: + """Arguments for `agent.generate`.""" + def __init__( + self, + *, + input: PromptInput, + tools: Optional[list[ToolSpec]], + tool_choice: Optional[ToolChoice], + structured_output_schema: Optional[str], + ) -> None: + setattr(self, 'input', input) + setattr(self, 'tools', tools) + setattr(self, 'tool-choice', tool_choice) + setattr(self, 'structured-output-schema', structured_output_schema) + + @property + def tool_choice(self) -> Optional[ToolChoice]: + return getattr(self, 'tool-choice') + + @property + def structured_output_schema(self) -> Optional[str]: + return getattr(self, 'structured-output-schema') + + def __repr__(self) -> str: + return f'InvokeArgs(input={getattr(self, 'input')!r}, tools={getattr(self, 'tools')!r}, tool_choice={getattr(self, 'tool-choice')!r}, structured_output_schema={getattr(self, 'structured-output-schema')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, InvokeArgs): + return NotImplemented + return getattr(self, 'input') == getattr(other, 'input') and getattr(self, 'tools') == getattr(other, 'tools') and getattr(self, 'tool-choice') == getattr(other, 'tool-choice') and getattr(self, 'structured-output-schema') == getattr(other, 'structured-output-schema') + + def __hash__(self) -> int: + return id(self) + + +class RespondArgs: + """Payload supplied when resuming from a human-in-the-loop interrupt.""" + def __init__( + self, + *, + interrupt_id: str, + response: str, + ) -> None: + setattr(self, 'interrupt-id', interrupt_id) + setattr(self, 'response', response) + + @property + def interrupt_id(self) -> str: + return getattr(self, 'interrupt-id') + + def __repr__(self) -> str: + return f'RespondArgs(interrupt_id={getattr(self, 'interrupt-id')!r}, response={getattr(self, 'response')!r})' + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RespondArgs): + return NotImplemented + return getattr(self, 'interrupt-id') == getattr(other, 'interrupt-id') and getattr(self, 'response') == getattr(other, 'response') + + def __hash__(self) -> int: + return id(self) + + +class AgentError: + """Why an agent-resource call failed.""" + pass + +class AgentError_NoSessionConfigured(AgentError, _WitVariantCase): + """The agent was constructed without a session config.""" + tag = 'no-session-configured' + +class AgentError_Storage(AgentError, _WitVariantCase): + """The storage backend rejected the operation.""" + tag = 'storage' + +class AgentError_InvalidInput(AgentError, _WitVariantCase): + """Supplied payload did not match the expected shape.""" + tag = 'invalid-input' + +class AgentError_UnknownInterrupt(AgentError, _WitVariantCase): + """Supplied `interrupt-id` does not match any live interrupt.""" + tag = 'unknown-interrupt' + +class AgentError_Internal(AgentError, _WitVariantCase): + """Catch-all for internal failures.""" + tag = 'internal' + +_AgentError_CASES: dict[str, type] = { + 'no-session-configured': AgentError_NoSessionConfigured, + 'storage': AgentError_Storage, + 'invalid-input': AgentError_InvalidInput, + 'unknown-interrupt': AgentError_UnknownInterrupt, + 'internal': AgentError_Internal, +} + +def _AgentError_lift(raw: _WitVariant) -> AgentError: + cls = _AgentError_CASES.get(raw.tag) + if cls is None: + raise ValueError(f'unknown AgentError arm: {raw.tag!r}') + return cls(raw.payload) +AgentError.lift = staticmethod(_AgentError_lift) # type: ignore[attr-defined] + +class Agent: + """An agent instance. Persistent across `generate` calls.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + @staticmethod + def new(config: AgentConfig, *, invoke: Any) -> 'Agent': + return Agent(invoke('[constructor]agent', (config,)), invoke) + + def generate(self, args: InvokeArgs) -> Any: + return self._invoke('[method]agent.generate', (self._handle, args,)) + + def get_messages(self) -> list[Message]: + return self._invoke('[method]agent.get-messages', (self._handle,)) + + def set_messages(self, messages: list[Message]) -> Any: + return self._invoke('[method]agent.set-messages', (self._handle, messages,)) + + def get_app_state(self) -> str: + return self._invoke('[method]agent.get-app-state', (self._handle,)) + + def set_app_state(self, json: str) -> Any: + return self._invoke('[method]agent.set-app-state', (self._handle, json,)) + + def get_model_state(self) -> str: + return self._invoke('[method]agent.get-model-state', (self._handle,)) + + def set_model_state(self, json: str) -> Any: + return self._invoke('[method]agent.set-model-state', (self._handle, json,)) + + def get_traces(self) -> list[AgentTrace]: + return self._invoke('[method]agent.get-traces', (self._handle,)) + + def get_metrics(self) -> AgentMetrics: + return self._invoke('[method]agent.get-metrics', (self._handle,)) + + def save_session(self) -> Any: + return self._invoke('[method]agent.save-session', (self._handle,)) + + def list_snapshots(self) -> Any: + return self._invoke('[method]agent.list-snapshots', (self._handle,)) + + def delete_session(self) -> Any: + return self._invoke('[method]agent.delete-session', (self._handle,)) + + +class EventStream: + """Pull-based stream of agent events; sync-WIT placeholder for `stream`.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def read(self) -> Optional[StreamEvent]: + return self._invoke('[method]event-stream.read', (self._handle,)) + + +class ResponseStream: + """Handle to an in-flight `generate` invocation.""" + # Wraps a wasmtime-py ResourceAny / ResourceHost handle. + # The runtime sets ._handle to the underlying resource and + # ._invoke to a callable that dispatches a method by WIT name. + + def __init__(self, handle: Any, invoke: Any = None) -> None: + self._handle = handle + self._invoke = invoke + + def events(self) -> Any: + return self._invoke('[method]response-stream.events', (self._handle,)) + + def respond(self, args: RespondArgs) -> Any: + return self._invoke('[method]response-stream.respond', (self._handle, args,)) + + def cancel(self) -> None: + return self._invoke('[method]response-stream.cancel', (self._handle,)) + diff --git a/strands-py-wasm/src/strands/_runtime.py b/strands-py-wasm/src/strands/_runtime.py new file mode 100644 index 0000000000..bd9580dacc --- /dev/null +++ b/strands-py-wasm/src/strands/_runtime.py @@ -0,0 +1,490 @@ +"""Wasmtime-backed runtime adapter for :class:`strands.Agent`. + +Implementation detail of the public SDK surface in ``strands/__init__.py``. +This module owns: + +* the process-wide ``Engine`` + ``Component`` (one wasm load per process) +* the per-Agent ``Store`` + ``Linker`` + ``Instance`` lifecycle +* host-import callbacks (``tool-provider.call-tool``, ``host-log``) +* the host-side ``tool-event-stream`` resource the wasm component drains +* the async :class:`EventStream` wrapper that turns the sync wasm + ``read()`` call into an ``AsyncIterator`` for SDK callers + +External SDK code talks to ``_AgentRuntime`` and ``EventStream``; everything +else (WASI setup, marshaling, resource bookkeeping) is private. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import threading +from collections import deque +from collections.abc import Iterable +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from wasmtime import Config, Engine, Store, WasiConfig +from wasmtime.component import ( + Component, + Linker, + ResourceAny, + ResourceHost, + ResourceType, + Variant, +) + +from . import _generated as _t + +if TYPE_CHECKING: + from . import Agent + + +_GUEST_LOGGER = logging.getLogger("strands.guest") +logger = logging.getLogger(__name__) + + +_singleton_lock = threading.RLock() +_engine: Engine | None = None +_component: Component | None = None + + +def _wasm_path() -> Path: + env = os.environ.get("STRANDS_AGENT_WASM") + if env: + return Path(env) + + bundled = Path(__file__).resolve().parent / "strands-agent.wasm" + if bundled.exists(): + return bundled + + sdk_root = Path(__file__).resolve().parents[3] + dev = sdk_root / "strands-wasm" / "dist" / "strands-agent.wasm" + if dev.exists(): + return dev + + raise FileNotFoundError( + "Could not locate strands-agent.wasm. Set STRANDS_AGENT_WASM, install the " + "bundled wheel, or build strands-wasm/dist/strands-agent.wasm." + ) + + +def _get_engine() -> Engine: + global _engine + with _singleton_lock: + if _engine is None: + cfg = Config() + cfg.wasm_component_model = True + _engine = Engine(cfg) + return _engine + + +def _get_component() -> Component: + global _component + with _singleton_lock: + if _component is None: + _component = Component.from_file(_get_engine(), str(_wasm_path())) + return _component + + +_TOOL_STREAM_RESOURCE_TYPE_TAG = 0x1 +_tool_stream_resource_type: ResourceType | None = None + + +class _HostToolEventStream: + """Host-side tool-event-stream backing a single ``call-tool`` invocation. + + Holds a queue of WIT ``tool-stream-event`` Variants. ``read()`` returns + them one at a time, then ``None`` after the terminal ``complete`` / + ``error`` event has been delivered. + """ + + def __init__(self) -> None: + self._events: deque[Variant] = deque() + self._closed = False + + def push(self, event: Variant) -> None: + self._events.append(event) + + def close(self) -> None: + self._closed = True + + def read(self) -> Variant | None: + if self._events: + return self._events.popleft() + return None + + +class _ToolStreamRegistry: + """Per-runtime registry for live host-side tool-event-stream reps. + + Scoped to a single :class:`_AgentRuntime` so streams are bounded by the + runtime's lifetime: when the runtime is GC'd, its registry goes with it + and any stragglers (e.g. from an aborted invocation) are released. + """ + + def __init__(self) -> None: + self._reps: dict[int, _HostToolEventStream] = {} + self._next_rep = 1 + self._lock = threading.Lock() + + def register(self, stream: _HostToolEventStream) -> int: + with self._lock: + rep = self._next_rep + self._next_rep += 1 + self._reps[rep] = stream + return rep + + def lookup(self, rep: int) -> _HostToolEventStream: + with self._lock: + return self._reps[rep] + + def drop(self, rep: int) -> None: + with self._lock: + self._reps.pop(rep, None) + + +def _ensure_tool_stream_type() -> ResourceType: + # ResourceType identity must be process-stable so wasmtime-py recognizes + # the same WIT resource across runtimes. + global _tool_stream_resource_type + if _tool_stream_resource_type is None: + _tool_stream_resource_type = ResourceType.host(_TOOL_STREAM_RESOURCE_TYPE_TAG) + return _tool_stream_resource_type + + +def _make_tool_call_handler(agent: Agent, registry: _ToolStreamRegistry): + def call_tool(store: Any, args: Any) -> ResourceHost: + name = getattr(args, "name", "") + raw_input = getattr(args, "input", "") + stream = _HostToolEventStream() + try: + tool = agent._lookup_tool(name) + content_list = tool.invoke(raw_input) + except Exception as exc: # noqa: BLE001 surface any tool exception as a tool-error + logger.exception("tool %r raised; surfacing as tool-error to guest", name) + stream.push(_t.ToolStreamEvent.error(_t.ToolError.execution_failed(str(exc)))) + stream.close() + else: + # content_list items are already ToolResultContent wire variants. + stream.push(_t.ToolStreamEvent.complete(content_list)) + stream.close() + + # Register, then hand ownership to the guest. On failure, drop the rep + # ourselves since the guest never received it and won't drop it. + rep = registry.register(stream) + try: + return ResourceHost.own(rep, _TOOL_STREAM_RESOURCE_TYPE_TAG) + except BaseException: + registry.drop(rep) + raise + + return call_tool + + +def _host_log(_store: Any, entry: Any) -> None: + level = getattr(entry, "level", "info") + if not isinstance(level, str): + level = str(level) + message = getattr(entry, "message", "") + context_raw = getattr(entry, "context", None) + extra = {} + if context_raw: + try: + extra = {"context": json.loads(context_raw)} + except Exception: + extra = {"context": context_raw} + py_level = { + "trace": logging.DEBUG, + "debug": logging.DEBUG, + "info": logging.INFO, + "warn": logging.WARNING, + "error": logging.ERROR, + }.get(level.lower(), logging.INFO) + _GUEST_LOGGER.log(py_level, message, extra={"strands": extra} if extra else None) + + +def _make_tool_event_stream_read(registry: _ToolStreamRegistry): + def _tool_event_stream_read(store: Any, handle: ResourceAny) -> Variant | None: + host = handle.to_host(store) + rep = host.rep + stream = registry.lookup(rep) + return stream.read() + return _tool_event_stream_read + + +def _trap(name: str): + def _f(*_a, **_k): + raise RuntimeError(f"host import not implemented: {name}") + return _f + + +_MODEL_EVENT_STREAM_TYPE_TAG = 0x2 + + +def _register_imports(linker: Linker, agent: Agent, registry: _ToolStreamRegistry) -> None: + tool_stream_type = _ensure_tool_stream_type() + # model-provider's host-side stream type. Only reached on custom providers. + model_event_stream_type = ResourceType.host(_MODEL_EVENT_STREAM_TYPE_TAG) + + with linker.root() as root: + with root.add_instance("strands:agent/host-log@0.1.0") as ns: + ns.add_func("log", _host_log) + + with root.add_instance("strands:agent/tools@0.1.0") as ns: + ns.add_resource("tool-event-stream", tool_stream_type, lambda _store, rep: registry.drop(rep)) + ns.add_func("[method]tool-event-stream.read", _make_tool_event_stream_read(registry)) + + with root.add_instance("strands:agent/tool-provider@0.1.0") as ns: + ns.add_func("call-tool", _make_tool_call_handler(agent, registry)) + + # Stubs for the imports the basic Agent.invoke flow never reaches. + with root.add_instance("strands:agent/model-provider@0.1.0") as ns: + ns.add_resource("model-event-stream", model_event_stream_type, lambda _s, _r: None) + ns.add_func("[method]model-event-stream.read", _trap("model-event-stream.read")) + ns.add_func("start-stream", _trap("model-provider.start-stream")) + ns.add_func("count-tokens", _trap("model-provider.count-tokens")) + + with root.add_instance("strands:agent/snapshot-storage@0.1.0") as ns: + for fname in ("save-snapshot", "load-snapshot", "list-snapshot-ids", + "delete-session", "load-manifest", "save-manifest"): + ns.add_func(fname, _trap(f"snapshot-storage.{fname}")) + + with root.add_instance("strands:agent/snapshot-trigger-handler@0.1.0") as ns: + ns.add_func("should-snapshot", _trap("snapshot-trigger-handler.should-snapshot")) + + with root.add_instance("strands:agent/edge-handler-registry@0.1.0") as ns: + ns.add_func("evaluate", _trap("edge-handler-registry.evaluate")) + + with root.add_instance("strands:agent/elicitation-handler@0.1.0") as ns: + ns.add_func("elicit", _trap("elicitation-handler.elicit")) + + +# --- Store + Linker ----------------------------------------------------- + +def _make_store_and_linker( + agent: Agent, registry: _ToolStreamRegistry +) -> tuple[Store, Linker]: + engine = _get_engine() + store = Store(engine) + wasi = WasiConfig() + wasi.inherit_stdout() + wasi.inherit_stderr() + wasi.inherit_env() + store.set_wasi(wasi) + store.set_wasi_http() + + linker = Linker(engine) + linker.allow_shadowing = True + linker.add_wasip2() + linker.add_wasi_http_async() + _register_imports(linker, agent, registry) + return store, linker + + +class EventStream: + """Async iterator over guest-emitted :class:`StreamEvent` values. + + Wraps the wasm-side ``[method]event-stream.read`` call. Each ``__anext__`` + runs ``read`` on a worker thread so the asyncio loop stays responsive + while the guest blocks waiting for the next event. + """ + + def __init__(self, runtime: _AgentRuntime, handle: ResourceAny) -> None: + self._runtime = runtime + self._handle: ResourceAny | None = handle + self._closed = False + + def __aiter__(self) -> EventStream: + return self + + async def __anext__(self) -> Any: + if self._closed or self._handle is None: + raise StopAsyncIteration + raw = await self._runtime.event_stream_read(self._handle) + if raw is None: + self._closed = True + handle = self._handle + self._handle = None + handle.drop(self._runtime._store) + raise StopAsyncIteration + return _t.StreamEvent.lift(raw) + + +class _AgentRuntime: + """Lazy wrapper around the wasm Agent resource for one ``strands.Agent``. + + Construction is split across :meth:`__init__` (sync, cheap) and + :meth:`async_init` (drives the wasm constructor through ``call_async``). + Callers must await ``async_init`` before invoking any other method. + """ + + def __init__(self, agent: Agent) -> None: + self._agent = agent + self._lock = threading.Lock() + self._tool_streams = _ToolStreamRegistry() + self._store, self._linker = _make_store_and_linker(agent, self._tool_streams) + self._instance = self._linker.instantiate(self._store, _get_component()) + self._funcs = _ApiFuncs(self._store, self._instance) + self._handle: ResourceAny | None = None + self._current_response: ResourceAny | None = None + + def init(self) -> None: + if self._handle is not None: + return + # The bindgen AgentConfig is already wire-shape; pass through. + self._handle = self._funcs.constructor(self._store, self._agent._config) + self._funcs.constructor.post_return(self._store) + + async def async_init(self) -> None: + # Async hook so callers don't need to know construction is sync. + self.init() + + async def generate(self, args: _t.InvokeArgs) -> EventStream: + # Run sync wasm calls in a worker thread so the asyncio loop stays free. + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self._generate_blocking, args) + + def _generate_blocking(self, args: _t.InvokeArgs) -> EventStream: + with self._lock: + response_handle: ResourceAny = self._funcs.generate(self._store, self._handle, args) + self._funcs.generate.post_return(self._store) + self._current_response = response_handle + stream_handle: ResourceAny = self._funcs.events(self._store, response_handle) + self._funcs.events.post_return(self._store) + return EventStream(self, stream_handle) + + def cancel(self) -> None: + with self._lock: + handle = self._current_response + if handle is None: + return + self._funcs.cancel(self._store, handle) + self._funcs.cancel.post_return(self._store) + + async def respond(self, args: _t.RespondArgs) -> None: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._respond_blocking, args) + + def _respond_blocking(self, args: _t.RespondArgs) -> None: + with self._lock: + handle = self._current_response + if handle is None: + from . import StrandsError + + raise StrandsError("respond() called with no in-flight invocation") + res = self._funcs.respond(self._store, handle, args) + self._funcs.respond.post_return(self._store) + _raise_on_err(res) + + async def get_messages(self) -> list[_t.Message]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self._get_messages_blocking) + + def _get_messages_blocking(self) -> list[Any]: + with self._lock: + raw = self._funcs.get_messages(self._store, self._handle) + self._funcs.get_messages.post_return(self._store) + return raw # wasmtime-py records expose the same kebab-case attrs as bindgen Message + + async def set_messages(self, messages: Iterable[Any]) -> None: + # bindgen Message instances are already wire-shape; pass through. + wit = list(messages) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._set_messages_blocking, wit) + + def _set_messages_blocking(self, wit: list[Any]) -> None: + with self._lock: + res = self._funcs.set_messages(self._store, self._handle, wit) + self._funcs.set_messages.post_return(self._store) + _raise_on_err(res) + + async def get_app_state(self) -> dict[str, Any]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self._get_app_state_blocking) + + def _get_app_state_blocking(self) -> dict[str, Any]: + with self._lock: + raw = self._funcs.get_app_state(self._store, self._handle) + self._funcs.get_app_state.post_return(self._store) + return json.loads(raw) if raw else {} + + async def set_app_state(self, state: dict[str, Any]) -> None: + payload = json.dumps(state) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._set_app_state_blocking, payload) + + def _set_app_state_blocking(self, payload: str) -> None: + with self._lock: + res = self._funcs.set_app_state(self._store, self._handle, payload) + self._funcs.set_app_state.post_return(self._store) + _raise_on_err(res) + + async def get_model_state(self) -> dict[str, Any]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self._get_model_state_blocking) + + def _get_model_state_blocking(self) -> dict[str, Any]: + with self._lock: + raw = self._funcs.get_model_state(self._store, self._handle) + self._funcs.get_model_state.post_return(self._store) + return json.loads(raw) if raw else {} + + async def set_model_state(self, state: dict[str, Any]) -> None: + payload = json.dumps(state) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._set_model_state_blocking, payload) + + def _set_model_state_blocking(self, payload: str) -> None: + with self._lock: + res = self._funcs.set_model_state(self._store, self._handle, payload) + self._funcs.set_model_state.post_return(self._store) + _raise_on_err(res) + + async def event_stream_read(self, handle: ResourceAny) -> Variant | None: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self._event_stream_read_blocking, handle) + + def _event_stream_read_blocking(self, handle: ResourceAny) -> Variant | None: + with self._lock: + raw = self._funcs.event_stream_read(self._store, handle) + self._funcs.event_stream_read.post_return(self._store) + return raw + + +def _raise_on_err(res: Any) -> None: + if isinstance(res, Variant) and res.tag == "err": + from . import StrandsError + + # res.payload is a wire AgentError Variant; surface it raw for now. + raise StrandsError(f"agent call failed: {res.payload!r}") + + +class _ApiFuncs: + """Caches every exported function from ``strands:agent/api@0.1.0``.""" + + def __init__(self, store: Store, instance: Any) -> None: + api = instance.get_export_index(store, "strands:agent/api@0.1.0") + if api is None: + raise RuntimeError("component is missing strands:agent/api@0.1.0 export") + + def f(name: str): + fn = instance.get_func(store, instance.get_export_index(store, name, api)) + if fn is None: + raise RuntimeError(f"missing api export: {name}") + return fn + + self.constructor = f("[constructor]agent") + self.generate = f("[method]agent.generate") + self.get_messages = f("[method]agent.get-messages") + self.set_messages = f("[method]agent.set-messages") + self.get_app_state = f("[method]agent.get-app-state") + self.set_app_state = f("[method]agent.set-app-state") + self.get_model_state = f("[method]agent.get-model-state") + self.set_model_state = f("[method]agent.set-model-state") + self.events = f("[method]response-stream.events") + self.respond = f("[method]response-stream.respond") + self.cancel = f("[method]response-stream.cancel") + self.event_stream_read = f("[method]event-stream.read") diff --git a/strands-py-wasm/src/strands/py.typed b/strands-py-wasm/src/strands/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/strands-ts/eslint.config.js b/strands-ts/eslint.config.js new file mode 100644 index 0000000000..53a935c47d --- /dev/null +++ b/strands-ts/eslint.config.js @@ -0,0 +1,155 @@ +import eslint from '@eslint/js' +import tseslint from '@typescript-eslint/eslint-plugin' +import tsparser from '@typescript-eslint/parser' +import tsdoc from 'eslint-plugin-tsdoc' + +export default [ + eslint.configs.recommended, + { + rules: { + // Disabled: TypeScript compiler catches all redeclaration cases and + // understands value/type namespace merging (e.g., const + type with + // same name). See https://typescript-eslint.io/rules/no-redeclare/ + 'no-redeclare': 'off', + }, + }, + // Apply SDK rules to src files + sdkRules({ + files: ['src/**/*.ts'], + tsconfig: './src/tsconfig.json', + }), + // Prevent non-vended-tools from importing vended-tools + noVendedToolsImports({ + files: ['src/**/*.ts'], + ignores: ['src/vended-tools/**/*.ts'], + }), + // Then unit-test rules to UTs + unitTestRules({ + files: ['src/**/__tests__/**/*.ts'], + tsconfig: './src/tsconfig.json', + }), + // Apply UT rules to the integ tests + unitTestRules({ + files: ['test/integ/**/*.ts'], + tsconfig: './test/integ/tsconfig.json', + }), + // Then stricter integ test rules + integTestRules({ + files: ['test/integ/**/*.ts'], + tsconfig: './test/integ/tsconfig.json', + }), +] + +function sdkRules(options) { + return { + files: options.files, + languageOptions: { + parser: tsparser, + parserOptions: { + ecmaVersion: 2022, + sourceType: 'module', + project: options.tsconfig, + }, + globals: { + console: 'readonly', + process: 'readonly', + setTimeout: 'readonly', + clearTimeout: 'readonly', + atob: 'readonly', + btoa: 'readonly', + crypto: 'readonly', + }, + }, + plugins: { + '@typescript-eslint': tseslint, + tsdoc: tsdoc, + }, + rules: { + ...tseslint.configs.recommended.rules, + '@typescript-eslint/no-explicit-any': 'error', + '@typescript-eslint/no-unused-vars': ['error', { argsIgnorePattern: '^_', varsIgnorePattern: '^_' }], + '@typescript-eslint/explicit-function-return-type': 'error', + '@typescript-eslint/explicit-module-boundary-types': 'error', + 'tsdoc/syntax': 'error', + }, + } +} + +function unitTestRules(options) { + return { + files: options.files, + languageOptions: { + parser: tsparser, + parserOptions: { + ecmaVersion: 2022, + sourceType: 'module', + project: options.tsconfig, + }, + globals: { + process: 'readonly', + console: 'readonly', + window: 'readonly', + document: 'readonly', + navigator: 'readonly', + setTimeout: 'readonly', + clearTimeout: 'readonly', + }, + }, + plugins: { + '@typescript-eslint': tseslint, + }, + rules: { + ...tseslint.configs.recommended.rules, + '@typescript-eslint/no-explicit-any': 'off', + '@typescript-eslint/no-unused-vars': ['error', { argsIgnorePattern: '^_', varsIgnorePattern: '^_' }], + '@typescript-eslint/explicit-function-return-type': 'off', + quotes: ['error', 'single', { avoidEscape: true }], + }, + } +} + +function integTestRules(options) { + return { + files: options.files, + languageOptions: { + parserOptions: { + project: options.tsconfig, + }, + }, + rules: { + 'no-restricted-imports': [ + 'error', + { + patterns: [ + { + group: ['../src', '../src/**'], + message: + 'Integration tests should use $/sdk/* path aliases instead of ../src. Test fixtures can import from $/sdk/*.', + }, + ], + }, + ], + }, + } +} + +function noVendedToolsImports(options) { + return { + files: options.files, + ignores: options.ignores, + rules: { + 'no-restricted-imports': [ + 'error', + { + patterns: [ + { + group: ['**/vended-tools', '**/vended-tools/**'], + message: + 'Core SDK files should not import from vended-tools. Vended tools are optional and independently importable.', + }, + ], + }, + ], + }, + } +} diff --git a/strands-ts/examples/README.md b/strands-ts/examples/README.md new file mode 100644 index 0000000000..c5934e300c --- /dev/null +++ b/strands-ts/examples/README.md @@ -0,0 +1,29 @@ +# Examples + +Sample applications demonstrating Strands Agents TypeScript SDK features. + +## Prerequisites + +- Node.js 20+ +- AWS credentials configured (for the default Bedrock model provider) + +## Running an Example + +Each example is a standalone project. From any example directory: + +```bash +npm install +npm start +``` + +## Available Examples + +| Example | Description | +|---------|-------------| +| [first-agent](./first-agent/) | Basic agent usage with tools, invoke, and streaming patterns | +| [graph](./graph/) | Graph multi-agent orchestration (linear, fan-out, streaming) | +| [swarm](./swarm/) | Swarm multi-agent orchestration (agent-driven handoffs) | +| [mcp](./mcp/) | Model Context Protocol integration with external tool servers | +| [agents-as-tools](./agents-as-tools/) | Agents as tools pattern (orchestrator delegates to specialized tool agents) | +| [browser-agent](./browser-agent/) | Browser-based agent with DOM manipulation canvas (OpenAI, Anthropic, Bedrock) | +| [telemetry](./telemetry/) | OpenTelemetry tracing with Jaeger (requires Docker, see its [README](./telemetry/README.md)) | diff --git a/strands-ts/examples/agents-as-tools/.gitignore b/strands-ts/examples/agents-as-tools/.gitignore new file mode 100644 index 0000000000..91a3983f34 --- /dev/null +++ b/strands-ts/examples/agents-as-tools/.gitignore @@ -0,0 +1,3 @@ +dist +node_modules +package-lock.json diff --git a/strands-ts/examples/agents-as-tools/package.json b/strands-ts/examples/agents-as-tools/package.json new file mode 100644 index 0000000000..88b70eaebf --- /dev/null +++ b/strands-ts/examples/agents-as-tools/package.json @@ -0,0 +1,22 @@ +{ + "name": "agents-as-tools-example", + "private": true, + "main": "dist/index.js", + "type": "module", + "scripts": { + "prepare": "npm ci --prefix ../../..", + "clean": "rm -rf dist node_modules package-lock.json", + "build": "tsc", + "start": "tsc && node dist/index.js" + }, + "workspaces": [ + "../../" + ], + "dependencies": { + "@strands-agents/sdk": "*" + }, + "devDependencies": { + "@types/node": "^20.0.0", + "typescript": "^5.5.0" + } +} diff --git a/strands-ts/examples/agents-as-tools/src/index.ts b/strands-ts/examples/agents-as-tools/src/index.ts new file mode 100644 index 0000000000..840dec1703 --- /dev/null +++ b/strands-ts/examples/agents-as-tools/src/index.ts @@ -0,0 +1,110 @@ +import { Agent, AgentResult, BedrockModel, tool } from '@strands-agents/sdk' +import { z } from 'zod' + +/** + * Teacher's Assistant — Agents as Tools + * + * An orchestrator agent routes student queries to specialized tool agents, + * each focused on a single subject area. This mirrors the Python + * "Teacher's Assistant" example using the agents-as-tools pattern. + */ + +function extractText(result: AgentResult): string { + return result.lastMessage.content.map((b) => ('text' in b ? b.text : '')).join('') +} + +const model = new BedrockModel({ maxTokens: 1024 }) + +// Specialized tool agents + +const mathAssistant = tool({ + name: 'math_assistant', + description: 'Handle mathematical calculations, problems, and concepts.', + inputSchema: z.object({ + query: z.string().describe('A math question or problem'), + }), + callback: async (input) => { + const agent = new Agent({ + model, + printer: false, + systemPrompt: `You are a math tutor. Solve problems step-by-step and explain your reasoning clearly.`, + }) + const result = await agent.invoke(input.query) + return extractText(result) + }, +}) + +const englishAssistant = tool({ + name: 'english_assistant', + description: 'Help with writing, grammar, literature, and composition.', + inputSchema: z.object({ + query: z.string().describe('An English or writing question'), + }), + callback: async (input) => { + const agent = new Agent({ + model, + printer: false, + systemPrompt: `You are an English tutor. Help with grammar, writing, literature analysis, and composition.`, + }) + const result = await agent.invoke(input.query) + return extractText(result) + }, +}) + +const computerScienceAssistant = tool({ + name: 'computer_science_assistant', + description: 'Answer questions about programming, algorithms, and data structures.', + inputSchema: z.object({ + query: z.string().describe('A computer science or programming question'), + }), + callback: async (input) => { + const agent = new Agent({ + model, + printer: false, + systemPrompt: `You are a computer science tutor. Explain programming concepts, algorithms, and data structures clearly with examples.`, + }) + const result = await agent.invoke(input.query) + return extractText(result) + }, +}) + +const generalAssistant = tool({ + name: 'general_assistant', + description: 'Handle general knowledge questions outside specialized subject areas.', + inputSchema: z.object({ + query: z.string().describe('A general knowledge question'), + }), + callback: async (input) => { + const agent = new Agent({ + model, + printer: false, + systemPrompt: `You are a helpful general assistant. Answer questions clearly and concisely.`, + }) + const result = await agent.invoke(input.query) + return extractText(result) + }, +}) + +// Orchestrator agent + +const teacher = new Agent({ + model, + systemPrompt: `You are TeachAssist, an educational orchestrator that routes student queries to specialists: +- Math questions → math_assistant +- Writing, grammar, literature → english_assistant +- Programming, algorithms, CS → computer_science_assistant +- Everything else → general_assistant + +Always select the most appropriate tool based on the student's query.`, + tools: [mathAssistant, englishAssistant, computerScienceAssistant, generalAssistant], +}) + +async function main(): Promise { + console.log("=== Teacher's Assistant — Agents as Tools ===\n") + + const response = await teacher.invoke('What is the time complexity of merge sort and why?') + console.log('\n=== Final Response ===') + console.log(extractText(response)) +} + +await main().catch(console.error) diff --git a/strands-ts/examples/agents-as-tools/tsconfig.json b/strands-ts/examples/agents-as-tools/tsconfig.json new file mode 100644 index 0000000000..0d30dfb862 --- /dev/null +++ b/strands-ts/examples/agents-as-tools/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "ES2022", + "lib": ["ES2022"], + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": true, + "skipLibCheck": true, + "module": "NodeNext", + "moduleResolution": "NodeNext", + "outDir": "./dist", + "rootDir": "./src", + "declaration": true, + "declarationMap": true, + "sourceMap": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "tests*"] +} diff --git a/strands-ts/examples/browser-agent/.gitignore b/strands-ts/examples/browser-agent/.gitignore new file mode 100644 index 0000000000..91a3983f34 --- /dev/null +++ b/strands-ts/examples/browser-agent/.gitignore @@ -0,0 +1,3 @@ +dist +node_modules +package-lock.json diff --git a/strands-ts/examples/browser-agent/README.md b/strands-ts/examples/browser-agent/README.md new file mode 100644 index 0000000000..de5385bddd --- /dev/null +++ b/strands-ts/examples/browser-agent/README.md @@ -0,0 +1,28 @@ +# Browser Agent Example + +A browser-based AI agent that can modify DOM elements through natural language commands. Supports OpenAI, Anthropic, and AWS Bedrock. + +**⚠️ WARNING: This example is for demonstration purposes only and should NOT be used in production.** The agent executes LLM-generated HTML, CSS, and JavaScript with minimal filtering. While the canvas is sandboxed in an iframe, this pattern is inherently unsafe for untrusted or production environments. + +## Quick Start + +```bash +# Install dependencies +npm install + +# Start dev server +npm run dev +``` + +Open the URL (usually `http://localhost:5173`), configure your API credentials in settings, and start chatting. + +## How It Works + +This example runs a Strands Agent directly in your browser that you can communicate with through the chat window. The agent has access to a custom tool called `update_canvas` that allows it to modify the canvas element displayed in the view with any combination of HTML, CSS, or JavaScript. + +When you send a message, the agent streams its response in real-time and decides whether to use the canvas tool based on your request. The agent maintains conversation history, so it understands context from previous messages. + +Try asking it: +- "Change the background to blue" +- "Add some cats to the canvas" +- "Add a border and center the text" diff --git a/strands-ts/examples/browser-agent/index.html b/strands-ts/examples/browser-agent/index.html new file mode 100644 index 0000000000..26279bceee --- /dev/null +++ b/strands-ts/examples/browser-agent/index.html @@ -0,0 +1,313 @@ + + + + + + + Strands Browser Agent Example + + + + + +

Browser Agent Example

+ +
+ ⚠️ Warning: This browser agent example is not productionized and may be unsafe. Do not use in production without + proper security measures. +
+ +
+
+ + +
+ +
+
+
Hello! I can modify the canvas on the left. 👈 +
Try asking me "change background to blue" or "make it a circle". +
+
+
+ + + +
+
+
+ +
+
+

Settings

+ + + +
+ + +
+ +
+ + +
+ +
+ + + + + + + + + + + +
+ +
+ + +
+
+
+ + + + + diff --git a/strands-ts/examples/browser-agent/package.json b/strands-ts/examples/browser-agent/package.json new file mode 100644 index 0000000000..cf777434bb --- /dev/null +++ b/strands-ts/examples/browser-agent/package.json @@ -0,0 +1,25 @@ +{ + "name": "browser-agent-example", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "prepare": "npm ci --prefix ../../..", + "dev": "vite", + "build": "vite build", + "preview": "vite preview" + }, + "dependencies": { + "@anthropic-ai/sdk": "^0.71.2", + "@strands-agents/sdk": "*", + "marked": "^17.0.3", + "openai": "^6.33.0" + }, + "devDependencies": { + "typescript": "^5.5.0", + "vite": "^8.0.9" + }, + "workspaces": [ + "../../" + ] +} diff --git a/strands-ts/examples/browser-agent/src/index.ts b/strands-ts/examples/browser-agent/src/index.ts new file mode 100644 index 0000000000..f7eb2a531f --- /dev/null +++ b/strands-ts/examples/browser-agent/src/index.ts @@ -0,0 +1,229 @@ +import { Agent, BedrockModel } from '@strands-agents/sdk' +import { OpenAIModel } from '@strands-agents/sdk/models/openai' +import { AnthropicModel } from '@strands-agents/sdk/models/anthropic' +import { updateCanvasTool } from './tools' +import { marked } from 'marked' + +marked.use({ async: false }) + +const messagesDiv = document.getElementById('messages')! +const inputForm = document.getElementById('input-area') as HTMLFormElement +const userInput = document.getElementById('user-input') as HTMLInputElement +const sendBtn = document.getElementById('send-btn') as HTMLButtonElement +const clearBtn = document.getElementById('clear-btn') as HTMLButtonElement +const settingsBtn = document.getElementById('settings-btn') as HTMLButtonElement +const settingsModal = document.getElementById('settings-modal')! +const providerSelect = document.getElementById('provider-select') as HTMLSelectElement +const saveSettingsBtn = document.getElementById('save-settings-btn') as HTMLButtonElement +const cancelSettingsBtn = document.getElementById('cancel-settings-btn') as HTMLButtonElement + +const openaiKeyInput = document.getElementById('openai-key') as HTMLInputElement +const anthropicKeyInput = document.getElementById('anthropic-key') as HTMLInputElement +const bedrockRegionInput = document.getElementById('bedrock-region') as HTMLInputElement +const bedrockAccessKeyInput = document.getElementById('bedrock-access-key') as HTMLInputElement +const bedrockSecretKeyInput = document.getElementById('bedrock-secret-key') as HTMLInputElement +const bedrockSessionTokenInput = document.getElementById('bedrock-session-token') as HTMLInputElement +const openaiFields = document.querySelector('.openai-fields') as HTMLElement +const anthropicFields = document.querySelector('.anthropic-fields') as HTMLElement +const bedrockFields = document.querySelector('.bedrock-fields') as HTMLElement + +// In-memory credential storage — not persisted across page refreshes +let credentials: Record = {} +let currentProvider = 'openai' + +const WELCOME_HTML = + '
Hello! I can modify the canvas on the left. 👈
Try asking me "change background to blue" or "make it a circle".
' + +function showToast(message: string): void { + const toast = document.createElement('div') + toast.textContent = message + toast.style.cssText = + 'position:fixed;top:2rem;left:50%;transform:translateX(-50%);background:#1d1d1f;color:white;padding:1rem 2rem;border-radius:8px;box-shadow:0 4px 12px rgba(0,0,0,0.3);z-index:2000;' + document.body.appendChild(toast) + setTimeout(() => toast.remove(), 3000) +} + +function toggleProviderFields(provider: string): void { + openaiFields.classList.toggle('show', provider === 'openai') + anthropicFields.classList.toggle('show', provider === 'anthropic') + bedrockFields.classList.toggle('show', provider === 'bedrock') +} + +function addMessage(role: 'user' | 'agent' | 'tool', text: string): HTMLDivElement { + const div = document.createElement('div') + div.className = `message ${role}` + div.textContent = text + messagesDiv.appendChild(div) + messagesDiv.scrollTop = messagesDiv.scrollHeight + return div +} + +function getModel(): BedrockModel | AnthropicModel | OpenAIModel { + if (currentProvider === 'bedrock') { + return new BedrockModel({ + region: credentials['bedrock_region'] || 'us-west-2', + clientConfig: { + credentials: { + accessKeyId: credentials['bedrock_access_key'], + secretAccessKey: credentials['bedrock_secret_key'], + ...(credentials['bedrock_session_token'] && { + sessionToken: credentials['bedrock_session_token'], + }), + }, + }, + }) + } + + if (currentProvider === 'anthropic') { + return new AnthropicModel({ + apiKey: credentials['anthropic_api_key'], + clientConfig: { + dangerouslyAllowBrowser: true, + }, + }) + } + + return new OpenAIModel({ + api: 'chat', + apiKey: credentials['openai_api_key'], + clientConfig: { + dangerouslyAllowBrowser: true, + }, + }) +} + +async function main(): Promise { + let agent: Agent + + function initializeAgent(): void { + const model = getModel() + agent = new Agent({ + model, + systemPrompt: `You are a creative and helpful browser assistant. +You can modify the html, script, and style of the canvas iframe on the page using the update_canvas tool. +Scripts run in the iframe context with access to document.body. +Always use the tool when the user asks for visual changes. +Be concise in your text responses.`, + tools: [updateCanvasTool], + }) + } + + // Disable input until agent is initialized + userInput.disabled = true + sendBtn.disabled = true + + // Show settings on load so user can enter credentials + settingsModal.classList.add('show') + toggleProviderFields(currentProvider) + + settingsBtn.addEventListener('click', () => { + providerSelect.value = currentProvider + openaiKeyInput.value = credentials['openai_api_key'] || '' + anthropicKeyInput.value = credentials['anthropic_api_key'] || '' + bedrockRegionInput.value = credentials['bedrock_region'] || 'us-west-2' + bedrockAccessKeyInput.value = credentials['bedrock_access_key'] || '' + bedrockSecretKeyInput.value = credentials['bedrock_secret_key'] || '' + bedrockSessionTokenInput.value = credentials['bedrock_session_token'] || '' + toggleProviderFields(currentProvider) + settingsModal.classList.add('show') + }) + + cancelSettingsBtn.addEventListener('click', () => { + settingsModal.classList.remove('show') + }) + + saveSettingsBtn.addEventListener('click', () => { + currentProvider = providerSelect.value + + if (currentProvider === 'openai') { + credentials['openai_api_key'] = openaiKeyInput.value + } else if (currentProvider === 'anthropic') { + credentials['anthropic_api_key'] = anthropicKeyInput.value + } else { + credentials['bedrock_region'] = bedrockRegionInput.value + credentials['bedrock_access_key'] = bedrockAccessKeyInput.value + credentials['bedrock_secret_key'] = bedrockSecretKeyInput.value + credentials['bedrock_session_token'] = bedrockSessionTokenInput.value + } + + settingsModal.classList.remove('show') + + try { + initializeAgent() + userInput.disabled = false + sendBtn.disabled = false + messagesDiv.innerHTML = WELCOME_HTML + showToast('Settings saved!') + } catch (err) { + console.error(`error=<${err}> | failed to initialize agent`) + userInput.disabled = true + sendBtn.disabled = true + showToast('Failed to initialize agent. Check your credentials.') + } + }) + + providerSelect.addEventListener('change', (e) => { + toggleProviderFields((e.target as HTMLSelectElement).value) + }) + + clearBtn.addEventListener('click', () => { + messagesDiv.innerHTML = WELCOME_HTML + if (agent) { + agent.messages = [] + } + }) + + inputForm.addEventListener('submit', async (e) => { + e.preventDefault() + const text = userInput.value.trim() + if (!text) return + + addMessage('user', text) + userInput.value = '' + userInput.disabled = true + sendBtn.disabled = true + + const loader = addMessage('agent', '') + loader.innerHTML = '...' + + try { + let fullText = '' + let messageDiv: HTMLDivElement | null = null + + for await (const event of agent.stream(text)) { + if (loader.parentNode) loader.remove() + if (event.type !== 'modelStreamUpdateEvent') continue + const modelEvent = event.event + + if (modelEvent.type === 'modelContentBlockStartEvent') { + if (modelEvent.start?.type === 'toolUseStart') { + const toolMsg = addMessage('tool', `🛠️ Using tool: ${modelEvent.start.name}...`) + toolMsg.style.fontSize = '0.8em' + toolMsg.style.color = '#666' + } else { + fullText = '' + messageDiv = addMessage('agent', '') + } + } else if (modelEvent.type === 'modelContentBlockDeltaEvent' && modelEvent.delta.type === 'textDelta') { + if (!messageDiv) messageDiv = addMessage('agent', '') + fullText += modelEvent.delta.text + try { + messageDiv.innerHTML = marked.parse(fullText) as string + } catch { + messageDiv.textContent = fullText + } + messagesDiv.scrollTop = messagesDiv.scrollHeight + } + } + } catch (err) { + console.error(err) + addMessage('agent', 'Error: ' + (err as Error).message) + } finally { + userInput.disabled = false + sendBtn.disabled = false + userInput.focus() + } + }) +} + +main() diff --git a/strands-ts/examples/browser-agent/src/tools.ts b/strands-ts/examples/browser-agent/src/tools.ts new file mode 100644 index 0000000000..ebdf7cccf5 --- /dev/null +++ b/strands-ts/examples/browser-agent/src/tools.ts @@ -0,0 +1,48 @@ +import { tool } from '@strands-agents/sdk' +import { z } from 'zod' + +export const updateCanvasTool = tool({ + name: 'update_canvas', + description: 'Update the style and content of the canvas element on the page', + inputSchema: z.object({ + html: z.string().optional().describe('HTML content to set as innerHTML of the canvas body element'), + style: z + .record(z.string(), z.string()) + .optional() + .describe( + 'JSON object containing CSS properties to apply to the canvas body element (e.g. {"backgroundColor": "red", "fontSize": "20px"})' + ), + script: z.string().optional().describe('JavaScript code to execute in the canvas iframe'), + }), + callback: (input): string => { + const canvas = document.getElementById('canvas') as HTMLIFrameElement + if (!canvas || !canvas.contentWindow) { + throw new Error('Canvas iframe not found') + } + + const updates: string[] = [] + const doc = canvas.contentDocument || canvas.contentWindow.document + const body = doc.body + + if (input.html) { + body.innerHTML = input.html + updates.push('html updated') + } + + if (input.style) { + Object.assign(body.style, input.style) + updates.push('style updated') + } + + if (input.script) { + canvas.contentWindow.eval(input.script) + updates.push('script executed') + } + + if (updates.length === 0) { + return 'No changes made.' + } + + return `Canvas updated: ${updates.join(', ')}` + }, +}) diff --git a/strands-ts/examples/browser-agent/tsconfig.json b/strands-ts/examples/browser-agent/tsconfig.json new file mode 100644 index 0000000000..92b519084a --- /dev/null +++ b/strands-ts/examples/browser-agent/tsconfig.json @@ -0,0 +1,11 @@ +{ + "compilerOptions": { + "target": "ES2022", + "lib": ["ES2022", "DOM", "DOM.Iterable"], + "module": "ESNext", + "moduleResolution": "bundler", + "strict": true, + "skipLibCheck": true + }, + "include": ["src"] +} diff --git a/strands-ts/examples/first-agent/.gitignore b/strands-ts/examples/first-agent/.gitignore new file mode 100644 index 0000000000..91a3983f34 --- /dev/null +++ b/strands-ts/examples/first-agent/.gitignore @@ -0,0 +1,3 @@ +dist +node_modules +package-lock.json diff --git a/strands-ts/examples/first-agent/package.json b/strands-ts/examples/first-agent/package.json new file mode 100644 index 0000000000..5827f61056 --- /dev/null +++ b/strands-ts/examples/first-agent/package.json @@ -0,0 +1,22 @@ +{ + "name": "first-agent", + "private": true, + "main": "dist/index.js", + "type": "module", + "scripts": { + "prepare": "npm ci --prefix ../../..", + "clean": "rm -rf dist node_modules package-lock.json", + "build": "tsc", + "start": "tsc && node dist/index.js" + }, + "workspaces": [ + "../../" + ], + "dependencies": { + "@strands-agents/sdk": "*" + }, + "devDependencies": { + "@types/node": "^20.0.0", + "typescript": "^5.5.0" + } +} diff --git a/strands-ts/examples/first-agent/src/index.ts b/strands-ts/examples/first-agent/src/index.ts new file mode 100644 index 0000000000..bb03f8bf72 --- /dev/null +++ b/strands-ts/examples/first-agent/src/index.ts @@ -0,0 +1,91 @@ +import { Agent, BedrockModel, tool } from '@strands-agents/sdk' +import { z } from 'zod' + +const weatherTool = tool({ + name: 'get_weather', + description: 'Get the current weather for a specific location.', + inputSchema: z.object({ + location: z.string().describe('The city and state, e.g., San Francisco, CA'), + }), + callback: (input) => { + const fakeWeatherData = { + temperature: '72°F', + conditions: 'sunny', + } + + return `The weather in ${input.location} is ${fakeWeatherData.temperature} and ${fakeWeatherData.conditions}.` + }, +}) + +/** + * Helper function to demonstrate the simple invoke() pattern. + * This is the recommended approach for most use cases. + * @param title The title of the scenario to be logged. + * @param agent The agent instance to use. + * @param prompt The user prompt to invoke the agent with. + */ +async function runInvoke(title: string, agent: Agent, prompt: string) { + console.log(`--- ${title} ---`) + console.log(`User: ${prompt}`) + + const result = await agent.invoke(prompt) + + console.log(`\n::Invocation complete; stop reason was ${result.stopReason}\n`) +} + +/** + * Helper function to demonstrate the stream() pattern. + * Use this when you need access to intermediate streaming events. + * @param title The title of the scenario to be logged. + * @param agent The agent instance to use. + * @param prompt The user prompt to invoke the agent with. + */ +async function runStreaming(title: string, agent: Agent, prompt: string) { + console.log(`--- ${title} ---`) + console.log(`User: ${prompt}`) + + console.log('Agent response stream:') + for await (const event of agent.stream(prompt)) { + console.log('[Event]', event.type) + } + + console.log('\nStreaming complete.\n') +} + +async function main() { + // 1. Initialize the components + const model = new BedrockModel() + + // 2. Create agents + const defaultAgent = new Agent() + const agentWithoutTools = new Agent({ model }) + const agentWithTools = new Agent({ + systemPrompt: + 'You are a helpful assistant that provides weather information using the get_weather tool. Always Inform the user if you run tools.', + model, + tools: [weatherTool], + }) + + // Demonstrate the simple invoke() pattern (recommended for most use cases) + console.log('=== Simple invoke() pattern ===\n') + await runInvoke('0: Invocation with default agent (no model or tools)', defaultAgent, 'Hello!') + await runInvoke('1: Invocation with a model but no tools', agentWithoutTools, 'Hello!') + await runInvoke( + '2: Invocation that uses a tool', + agentWithTools, + 'What is the weather in Toronto? Use the weather tool.' + ) + + const streamingAgentWithTools = new Agent({ + systemPrompt: 'You are a helpful assistant that provides weather information using the get_weather tool.', + model, + tools: [weatherTool], + printer: false, + }) + + // Demonstrate the stream() pattern (for when you need intermediate events) + console.log('\n=== Streaming pattern (advanced) ===\n') + await runStreaming('3: Streaming invocation with events', streamingAgentWithTools, 'What is the weather in Seattle?') +} + +await main().catch(console.error) diff --git a/strands-ts/examples/first-agent/tsconfig.json b/strands-ts/examples/first-agent/tsconfig.json new file mode 100644 index 0000000000..0d30dfb862 --- /dev/null +++ b/strands-ts/examples/first-agent/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "ES2022", + "lib": ["ES2022"], + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": true, + "skipLibCheck": true, + "module": "NodeNext", + "moduleResolution": "NodeNext", + "outDir": "./dist", + "rootDir": "./src", + "declaration": true, + "declarationMap": true, + "sourceMap": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "tests*"] +} diff --git a/strands-ts/examples/graph/.gitignore b/strands-ts/examples/graph/.gitignore new file mode 100644 index 0000000000..91a3983f34 --- /dev/null +++ b/strands-ts/examples/graph/.gitignore @@ -0,0 +1,3 @@ +dist +node_modules +package-lock.json diff --git a/strands-ts/examples/graph/package.json b/strands-ts/examples/graph/package.json new file mode 100644 index 0000000000..ecaf2a887b --- /dev/null +++ b/strands-ts/examples/graph/package.json @@ -0,0 +1,22 @@ +{ + "name": "graph-example", + "private": true, + "main": "dist/index.js", + "type": "module", + "scripts": { + "prepare": "npm ci --prefix ../../..", + "clean": "rm -rf dist node_modules package-lock.json", + "build": "tsc", + "start": "tsc && node dist/index.js" + }, + "workspaces": [ + "../../" + ], + "dependencies": { + "@strands-agents/sdk": "*" + }, + "devDependencies": { + "@types/node": "^20.0.0", + "typescript": "^5.5.0" + } +} diff --git a/strands-ts/examples/graph/src/index.ts b/strands-ts/examples/graph/src/index.ts new file mode 100644 index 0000000000..9e053a77d0 --- /dev/null +++ b/strands-ts/examples/graph/src/index.ts @@ -0,0 +1,83 @@ +import { Agent, BedrockModel, Graph } from '@strands-agents/sdk' + +async function main() { + const model = new BedrockModel({ maxTokens: 1024 }) + + // Define agents as graph nodes + const researcher = new Agent({ + model, + printer: false, + id: 'researcher', + systemPrompt: 'Research the topic and provide key facts in 2-3 sentences.', + }) + + const writer = new Agent({ + model, + printer: false, + id: 'writer', + systemPrompt: 'Rewrite the research into a polished, concise paragraph.', + }) + + // Linear graph: researcher -> writer + console.log('=== Linear Graph ===\n') + const linearGraph = new Graph({ + nodes: [researcher, writer], + edges: [['researcher', 'writer']], + }) + + const linearResult = await linearGraph.invoke('What is the largest ocean on Earth?') + console.log('Status:', linearResult.status) + console.log('Output:', linearResult.content.find((b) => b.type === 'textBlock')?.text) + + // Fan-out graph: router -> [capitals, oceans] (parallel execution) + console.log('\n=== Fan-Out Graph ===\n') + const router = new Agent({ + model, + printer: false, + id: 'router', + systemPrompt: 'Repeat the user input exactly.', + }) + + const capitals = new Agent({ + model, + printer: false, + id: 'capitals', + systemPrompt: 'Answer with only the capital of France.', + }) + + const oceans = new Agent({ + model, + printer: false, + id: 'oceans', + systemPrompt: 'Answer with only the largest ocean.', + }) + + const fanOutGraph = new Graph({ + nodes: [router, capitals, oceans], + edges: [ + ['router', 'capitals'], + ['router', 'oceans'], + ], + }) + + const fanOutResult = await fanOutGraph.invoke('Go') + console.log('Status:', fanOutResult.status) + console.log('Nodes executed:', fanOutResult.results.map((r) => r.nodeId).join(', ')) + for (const block of fanOutResult.content) { + if (block.type === 'textBlock') { + console.log('Output:', block.text) + } + } + + // Streaming: access events as nodes execute + console.log('\n=== Streaming Graph ===\n') + for await (const event of linearGraph.stream('Explain quantum computing briefly.')) { + if (event.type === 'multiAgentHandoffEvent') { + console.log(`Handoff: ${event.source} -> ${event.targets.join(', ')}`) + } else if (event.type === 'nodeResultEvent') { + console.log(`Node ${event.result.nodeId}: ${event.result.status}`) + } + } +} + +await main().catch(console.error) diff --git a/strands-ts/examples/graph/tsconfig.json b/strands-ts/examples/graph/tsconfig.json new file mode 100644 index 0000000000..0d30dfb862 --- /dev/null +++ b/strands-ts/examples/graph/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "ES2022", + "lib": ["ES2022"], + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": true, + "skipLibCheck": true, + "module": "NodeNext", + "moduleResolution": "NodeNext", + "outDir": "./dist", + "rootDir": "./src", + "declaration": true, + "declarationMap": true, + "sourceMap": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "tests*"] +} diff --git a/strands-ts/examples/mcp/.gitignore b/strands-ts/examples/mcp/.gitignore new file mode 100644 index 0000000000..91a3983f34 --- /dev/null +++ b/strands-ts/examples/mcp/.gitignore @@ -0,0 +1,3 @@ +dist +node_modules +package-lock.json diff --git a/strands-ts/examples/mcp/package.json b/strands-ts/examples/mcp/package.json new file mode 100644 index 0000000000..5827f61056 --- /dev/null +++ b/strands-ts/examples/mcp/package.json @@ -0,0 +1,22 @@ +{ + "name": "first-agent", + "private": true, + "main": "dist/index.js", + "type": "module", + "scripts": { + "prepare": "npm ci --prefix ../../..", + "clean": "rm -rf dist node_modules package-lock.json", + "build": "tsc", + "start": "tsc && node dist/index.js" + }, + "workspaces": [ + "../../" + ], + "dependencies": { + "@strands-agents/sdk": "*" + }, + "devDependencies": { + "@types/node": "^20.0.0", + "typescript": "^5.5.0" + } +} diff --git a/strands-ts/examples/mcp/src/index.ts b/strands-ts/examples/mcp/src/index.ts new file mode 100644 index 0000000000..ba2de5d73e --- /dev/null +++ b/strands-ts/examples/mcp/src/index.ts @@ -0,0 +1,80 @@ +import { Agent, McpClient } from '@strands-agents/sdk' +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' + +async function runInvoke(title: string, agent: Agent, prompt: string) { + console.log(`--- ${title} ---\nUser: ${prompt}`) + const result = await agent.invoke(prompt) + console.log(`\n\n::Invocation complete; stop reason was ${result.stopReason}\n`) +} + +async function main() { + if (!process.env.STRANDS_EXAMPLE_MCP_DEMO) { + console.warn( + 'Skipping MCP client example; STRANDS_EXAMPLE_MCP_DEMO environment variable not set. If you are comfortable with these tools performing side effects than you can set it and re-run the example.' + ) + return + } + + const documentationTools = new McpClient({ + transport: new StdioClientTransport({ + command: 'uvx', + args: ['awslabs.aws-documentation-mcp-server@latest'], + }), + }) + + const agentWithMcpClient = new Agent({ + systemPrompt: + 'You are a helpful assistant that uses the aws-documentation-mcp-server server as a demonstration of mcp functionality. You must only use tools without side effects.', + tools: [documentationTools], + }) + + await runInvoke('1: Invocation with MCP client', agentWithMcpClient, 'Use a random tool from the MCP server.') + + // Set the following environment variable to run the GitHub MCP client example. + // + // STRANDS_EXAMPLE_GITHUB_PAT= + // + // Though unlikely in practice, this can perform side effects when using certain tools. + if (!process.env.STRANDS_EXAMPLE_GITHUB_PAT) { + console.warn( + 'Skipping GitHub MCP client example; STRANDS_EXAMPLE_GITHUB_PAT environment variable not set. Though prompted not to, this can perform side effects when using certain tools.' + ) + await documentationTools.disconnect() + return + } + + // Optional client configuration + const applicationConfig = { + applicationName: 'First Agent Example', + applicationVersion: '0.0.0', + } + + // Create a remote MCP client + const githubMcpClient = new McpClient({ + ...applicationConfig, + transport: new StreamableHTTPClientTransport(new URL('https://api.githubcopilot.com/mcp/'), { + requestInit: { + headers: { + Authorization: `Bearer ${process.env.STRANDS_EXAMPLE_GITHUB_PAT}`, + }, + }, + }), + }) + + const agentWithGithubMcpClient = new Agent({ + systemPrompt: + 'You are a helpful assistant that uses the github_mcp server as a demonstration of mcp functionality. You must only use tools without side effects.', + tools: [githubMcpClient], + }) + + await runInvoke( + '2: Invocation with GitHub MCP client', + agentWithGithubMcpClient, + 'Use a random tool from the GitHub MCP server to illustrate that they work.' + ) + + await Promise.all([documentationTools.disconnect(), githubMcpClient.disconnect()]) +} + +await main().catch(console.error) diff --git a/strands-ts/examples/mcp/tsconfig.json b/strands-ts/examples/mcp/tsconfig.json new file mode 100644 index 0000000000..0d30dfb862 --- /dev/null +++ b/strands-ts/examples/mcp/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "ES2022", + "lib": ["ES2022"], + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": true, + "skipLibCheck": true, + "module": "NodeNext", + "moduleResolution": "NodeNext", + "outDir": "./dist", + "rootDir": "./src", + "declaration": true, + "declarationMap": true, + "sourceMap": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "tests*"] +} diff --git a/strands-ts/examples/swarm/.gitignore b/strands-ts/examples/swarm/.gitignore new file mode 100644 index 0000000000..91a3983f34 --- /dev/null +++ b/strands-ts/examples/swarm/.gitignore @@ -0,0 +1,3 @@ +dist +node_modules +package-lock.json diff --git a/strands-ts/examples/swarm/package.json b/strands-ts/examples/swarm/package.json new file mode 100644 index 0000000000..3890eaddfa --- /dev/null +++ b/strands-ts/examples/swarm/package.json @@ -0,0 +1,22 @@ +{ + "name": "swarm-example", + "private": true, + "main": "dist/index.js", + "type": "module", + "scripts": { + "prepare": "npm ci --prefix ../../..", + "clean": "rm -rf dist node_modules package-lock.json", + "build": "tsc", + "start": "tsc && node dist/index.js" + }, + "workspaces": [ + "../../" + ], + "dependencies": { + "@strands-agents/sdk": "*" + }, + "devDependencies": { + "@types/node": "^20.0.0", + "typescript": "^5.5.0" + } +} diff --git a/strands-ts/examples/swarm/src/index.ts b/strands-ts/examples/swarm/src/index.ts new file mode 100644 index 0000000000..3dd7c71da0 --- /dev/null +++ b/strands-ts/examples/swarm/src/index.ts @@ -0,0 +1,48 @@ +import { Agent, BedrockModel, Swarm } from '@strands-agents/sdk' + +async function main() { + const model = new BedrockModel({ maxTokens: 1024 }) + + // Define swarm agents with descriptions (used for routing decisions) + const researcher = new Agent({ + model, + printer: false, + id: 'researcher', + description: 'Researches a topic and gathers key facts.', + systemPrompt: + 'You are a researcher. Look up the answer, then hand off to the writer agent. Never produce a final response yourself.', + }) + + const writer = new Agent({ + model, + printer: false, + id: 'writer', + description: 'Writes a polished final answer.', + systemPrompt: 'Write the final answer in one clear paragraph. Do not hand off to another agent.', + }) + + // Swarm: researcher hands off to writer via structured output + console.log('=== Swarm Orchestration ===\n') + const swarm = new Swarm({ + nodes: [researcher, writer], + start: 'researcher', + maxSteps: 4, + }) + + const result = await swarm.invoke('What is the largest ocean on Earth?') + console.log('Status:', result.status) + console.log('Agents executed:', result.results.map((r) => r.nodeId).join(' -> ')) + console.log('Output:', result.content.find((b) => b.type === 'textBlock')?.text) + + // Streaming: access handoff events in real-time + console.log('\n=== Streaming Swarm ===\n') + for await (const event of swarm.stream('Explain quantum computing briefly.')) { + if (event.type === 'multiAgentHandoffEvent') { + console.log(`Handoff: ${event.source} -> ${event.targets.join(', ')}`) + } else if (event.type === 'nodeResultEvent') { + console.log(`Node ${event.result.nodeId}: ${event.result.status}`) + } + } +} + +await main().catch(console.error) diff --git a/strands-ts/examples/swarm/tsconfig.json b/strands-ts/examples/swarm/tsconfig.json new file mode 100644 index 0000000000..0d30dfb862 --- /dev/null +++ b/strands-ts/examples/swarm/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "ES2022", + "lib": ["ES2022"], + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": true, + "skipLibCheck": true, + "module": "NodeNext", + "moduleResolution": "NodeNext", + "outDir": "./dist", + "rootDir": "./src", + "declaration": true, + "declarationMap": true, + "sourceMap": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "tests*"] +} diff --git a/strands-ts/examples/telemetry/.gitignore b/strands-ts/examples/telemetry/.gitignore new file mode 100644 index 0000000000..91a3983f34 --- /dev/null +++ b/strands-ts/examples/telemetry/.gitignore @@ -0,0 +1,3 @@ +dist +node_modules +package-lock.json diff --git a/strands-ts/examples/telemetry/README.md b/strands-ts/examples/telemetry/README.md new file mode 100644 index 0000000000..6d85072670 --- /dev/null +++ b/strands-ts/examples/telemetry/README.md @@ -0,0 +1,55 @@ +# Strands Agents — Jaeger Tracing Example + +Send traces from a Strands agent to a local [Jaeger](https://www.jaegertracing.io/) instance and visualize them in the Jaeger UI. + +## Architecture + +```mermaid +flowchart LR + A["Strands Agent
(your code)"] -- OTLP --> B["OTel Collector
(batch + export)
localhost:4318"] + B -- OTLP --> C["Jaeger
(traces)
localhost:16686"] +``` + +The agent exports spans over OTLP HTTP to an OpenTelemetry Collector, which +batches and forwards them to Jaeger. Both the collector and Jaeger run locally +via Docker Compose. + +## Prerequisites + +- Docker (or [Finch](https://github.com/runfinch/finch)) +- Node.js 18+ +- AWS credentials configured (for Bedrock model access) + +## Quick Start + +1. Start Jaeger and the OTel Collector: + +```bash +docker compose up -d +``` + +(Or `finch compose up -d` if using Finch.) + +2. Install dependencies: + +```bash +npm install +``` + +3. Run the example: + +```bash +OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 npm start +``` + +4. Open the Jaeger UI at [http://localhost:16686](http://localhost:16686). + Select `strands-agents` from the service dropdown and click **Find Traces**. + + You'll see the full trace hierarchy — agent invocation, loop cycles, model + calls, and tool executions nested under each agent span. + +5. Tear down when done: + +```bash +docker compose down +``` \ No newline at end of file diff --git a/strands-ts/examples/telemetry/docker-compose.yml b/strands-ts/examples/telemetry/docker-compose.yml new file mode 100644 index 0000000000..0c22c90d35 --- /dev/null +++ b/strands-ts/examples/telemetry/docker-compose.yml @@ -0,0 +1,20 @@ +services: + # OpenTelemetry Collector — receives spans from the agent and forwards them + # to Jaeger. In production you'd swap/add exporters for your backend of choice. + otel-collector: + image: otel/opentelemetry-collector-contrib:latest + volumes: + - ./otel-collector-config.yaml:/etc/otelcol-contrib/config.yaml + ports: + - "4317:4317" # OTLP gRPC receiver + - "4318:4318" # OTLP HTTP receiver + depends_on: + - jaeger + + # Jaeger — trace visualization UI + jaeger: + image: jaegertracing/all-in-one:latest + ports: + - "16686:16686" # Jaeger UI + environment: + - COLLECTOR_OTLP_ENABLED=true diff --git a/strands-ts/examples/telemetry/otel-collector-config.yaml b/strands-ts/examples/telemetry/otel-collector-config.yaml new file mode 100644 index 0000000000..6d6db0255e --- /dev/null +++ b/strands-ts/examples/telemetry/otel-collector-config.yaml @@ -0,0 +1,31 @@ +# OpenTelemetry Collector configuration +# Receives traces from the Strands agent and exports them to Jaeger for visualization. + +receivers: + otlp: + protocols: + http: + endpoint: 0.0.0.0:4318 + grpc: + endpoint: 0.0.0.0:4317 + +processors: + batch: + timeout: 1s + send_batch_size: 1024 + +exporters: + # Export traces to Jaeger + otlphttp/jaeger: + endpoint: http://jaeger:4318 + + # Log traces to the collector's stdout (useful for debugging the pipeline) + debug: + verbosity: basic + +service: + pipelines: + traces: + receivers: [otlp] + processors: [batch] + exporters: [otlphttp/jaeger, debug] diff --git a/strands-ts/examples/telemetry/package.json b/strands-ts/examples/telemetry/package.json new file mode 100644 index 0000000000..e3c7e9c8ef --- /dev/null +++ b/strands-ts/examples/telemetry/package.json @@ -0,0 +1,28 @@ +{ + "name": "telemetry-example", + "private": true, + "main": "dist/setup-tracer.js", + "type": "module", + "scripts": { + "prepare": "npm ci --prefix ../../..", + "clean": "rm -rf dist node_modules package-lock.json", + "build": "tsc", + "start": "tsc && node dist/setup-tracer.js", + "start:custom-provider": "tsc && node dist/custom-provider.js" + }, + "workspaces": [ + "../../" + ], + "dependencies": { + "@strands-agents/sdk": "*" + }, + "devDependencies": { + "@opentelemetry/api": "^1.9.0", + "@opentelemetry/exporter-trace-otlp-http": "^0.57.2", + "@opentelemetry/resources": "^1.30.1", + "@opentelemetry/sdk-trace-base": "^1.30.1", + "@opentelemetry/sdk-trace-node": "^1.30.1", + "@types/node": "^20.0.0", + "typescript": "^5.5.0" + } +} diff --git a/strands-ts/examples/telemetry/src/custom-provider.ts b/strands-ts/examples/telemetry/src/custom-provider.ts new file mode 100644 index 0000000000..8224105caf --- /dev/null +++ b/strands-ts/examples/telemetry/src/custom-provider.ts @@ -0,0 +1,114 @@ +/** + * Telemetry example using your own NodeTracerProvider. + * + * Use this approach when you need full control over the OpenTelemetry setup — + * for example, to add custom span processors, use a specific resource + * configuration, or integrate with an existing observability pipeline. + * + * The Agent class uses the global OTel API (`trace.getTracer(...)`) internally, + * so any provider registered via `provider.register()` is automatically picked + * up — no need to pass it to the SDK. + * + * Run with OTLP exporter (e.g. Jaeger at localhost:4318): + * OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 npm run start:custom-provider + * + * Run with console exporter for local debugging: + * npm run start:custom-provider + */ + +import { Agent, tool } from '@strands-agents/sdk' +import { z } from 'zod' + +// OpenTelemetry imports — you manage these directly +import { Resource } from '@opentelemetry/resources' +import { NodeTracerProvider } from '@opentelemetry/sdk-trace-node' +import { SimpleSpanProcessor, BatchSpanProcessor } from '@opentelemetry/sdk-trace-base' +import { ConsoleSpanExporter } from '@opentelemetry/sdk-trace-node' +import { OTLPTraceExporter } from '@opentelemetry/exporter-trace-otlp-http' + +// 1. Create your own Resource with custom attributes. +const resource = new Resource({ + 'service.name': 'my-custom-app', + 'service.version': '2.0.0', + 'service.namespace': 'my-team', + 'deployment.environment': 'staging', + 'custom.attribute': 'hello-from-custom-provider', +}) + +// 2. Create and configure your own NodeTracerProvider. +const provider = new NodeTracerProvider({ resource }) + +// 3. Add span processors / exporters as needed. +// - ConsoleSpanExporter: prints spans to stdout (useful for debugging) +// - OTLPTraceExporter: sends spans to an OTLP-compatible backend +provider.addSpanProcessor(new SimpleSpanProcessor(new ConsoleSpanExporter())) + +if (process.env.OTEL_EXPORTER_OTLP_ENDPOINT) { + provider.addSpanProcessor(new BatchSpanProcessor(new OTLPTraceExporter())) +} + +// 4. Register the provider globally. +// This sets up the global tracer provider, context manager, and propagators. +// The Strands Agent will automatically pick it up via `trace.getTracer(...)`. +provider.register() + +console.log('=== Custom Provider Resource Attributes ===\n') +for (const [key, value] of Object.entries(provider.resource.attributes)) { + console.log(` ${key}: ${value}`) +} +console.log('') + +// 5. Define tools as usual — nothing changes on the application side. +const calculateTool = tool({ + name: 'calculate', + description: 'Perform a basic arithmetic calculation.', + inputSchema: z.object({ + expression: z.string().describe('A math expression, e.g., "2 + 2"'), + }), + callback: (input) => { + try { + // Simple eval for demo purposes only + const result = Function(`"use strict"; return (${input.expression})`)() + return `${input.expression} = ${result}` + } catch { + return `Could not evaluate: ${input.expression}` + } + }, +}) + +const greetTool = tool({ + name: 'greet', + description: 'Generate a greeting for a person.', + inputSchema: z.object({ + name: z.string().describe('The name of the person to greet'), + }), + callback: (input) => { + return `Hello, ${input.name}! Welcome aboard.` + }, +}) + +async function main() { + // 6. Create an agent — it automatically uses your custom provider. + const agent = new Agent({ + name: 'custom-traced-agent', + systemPrompt: + 'You are a helpful assistant. Use the calculate tool for math questions and the greet tool to greet people.', + tools: [calculateTool, greetTool], + traceAttributes: { + 'app.example': 'custom-provider', + }, + }) + + console.log('=== Invoking Agent ===\n') + const result = await agent.invoke('Please greet Alice, then calculate 42 * 17 for me.') + console.log(`\nStop reason: ${result.stopReason}`) + + // 7. Flush and shut down the provider when done. + // This ensures all buffered spans are exported before the process exits. + await provider.forceFlush() + await provider.shutdown() + + console.log('\nDone! Check your observability backend for traces.') +} + +await main().catch(console.error) diff --git a/strands-ts/examples/telemetry/src/setup-tracer.ts b/strands-ts/examples/telemetry/src/setup-tracer.ts new file mode 100644 index 0000000000..d7190537bf --- /dev/null +++ b/strands-ts/examples/telemetry/src/setup-tracer.ts @@ -0,0 +1,106 @@ +/** + * Telemetry example using the built-in setupTracer() helper. + * + * This is the recommended approach for most use cases. The SDK creates and + * configures a NodeTracerProvider internally, and the Agent automatically + * traces all invocations, model calls, and tool executions. + * + * Run with OTLP exporter (e.g. Jaeger at localhost:4318): + * OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 npm start + * + * Run with console exporter for local debugging: + * npm start + * + * Customize resource attributes: + * OTEL_SERVICE_NAME=my-app \ + * OTEL_RESOURCE_ATTRIBUTES="service.version=1.0.0,team=platform" \ + * npm start + */ + +import { Agent, tool } from '@strands-agents/sdk' +import { setupTracer } from '@strands-agents/sdk/telemetry' +import { z } from 'zod' + +// 1. Set up telemetry ONCE at application start. +// setupTracer() creates a NodeTracerProvider with sensible defaults and +// registers it globally. All agents will automatically pick it up. +const provider = setupTracer({ + exporters: { + // Send spans to an OTLP-compatible backend (Jaeger, Grafana, etc.) + // Uses OTEL_EXPORTER_OTLP_ENDPOINT env var for the endpoint. + otlp: true, + // Also print spans to the console for local debugging. + console: true, + }, +}) + +// You can inspect the resource attributes that will be attached to all spans. +console.log('=== Resource Attributes ===\n') +for (const [key, value] of Object.entries(provider.resource.attributes)) { + console.log(` ${key}: ${value}`) +} +console.log('') + +// 2. Define tools as usual +const getWeather = tool({ + name: 'get_weather', + description: 'Get the current weather for a specific location.', + inputSchema: z.object({ + location: z.string().describe('The city and state, e.g., San Francisco, CA'), + }), + callback: (input) => { + return `The weather in ${input.location} is 72°F and sunny.` + }, +}) + +const getTime = tool({ + name: 'get_time', + description: 'Get the current time for a timezone.', + inputSchema: z.object({ + timezone: z.string().describe('The timezone, e.g., America/New_York'), + }), + callback: (input) => { + return `The current time in ${input.timezone} is 3:00 PM.` + }, +}) + +async function main() { + // 3. Create agents — telemetry is automatically active. + // Use `name` and `traceAttributes` for richer trace metadata. + const weatherAgent = new Agent({ + name: 'weather-agent', + systemPrompt: 'You are a helpful weather assistant. Use the get_weather tool to answer questions.', + tools: [getWeather], + traceAttributes: { 'app.module': 'weather' }, + }) + + const timeAgent = new Agent({ + name: 'time-agent', + systemPrompt: 'You are a helpful time assistant. Use the get_time tool to answer questions.', + tools: [getTime], + traceAttributes: { 'app.module': 'time' }, + }) + + // 4. Invoke agents — each creates its own trace with nested spans for + // agent invocation, loop cycles, model calls, and tool executions. + console.log('=== Running Weather Agent ===\n') + const weatherResult = await weatherAgent.invoke('What is the weather in Seattle?') + console.log(`\nWeather agent stop reason: ${weatherResult.stopReason}\n`) + + console.log('=== Running Time Agent ===\n') + const timeResult = await timeAgent.invoke('What time is it in Tokyo?') + console.log(`\nTime agent stop reason: ${timeResult.stopReason}\n`) + + // 5. Agents can also run concurrently — traces remain isolated. + console.log('=== Running Both Agents Concurrently ===\n') + const [concurrentWeather, concurrentTime] = await Promise.all([ + weatherAgent.invoke('What is the weather in New York?'), + timeAgent.invoke('What time is it in London?'), + ]) + + console.log(`\nConcurrent weather stop reason: ${concurrentWeather.stopReason}`) + console.log(`Concurrent time stop reason: ${concurrentTime.stopReason}`) + console.log('\nDone! Check your observability backend for traces.') +} + +await main().catch(console.error) diff --git a/strands-ts/examples/telemetry/tsconfig.json b/strands-ts/examples/telemetry/tsconfig.json new file mode 100644 index 0000000000..0d30dfb862 --- /dev/null +++ b/strands-ts/examples/telemetry/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "ES2022", + "lib": ["ES2022"], + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "strict": true, + "skipLibCheck": true, + "module": "NodeNext", + "moduleResolution": "NodeNext", + "outDir": "./dist", + "rootDir": "./src", + "declaration": true, + "declarationMap": true, + "sourceMap": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "tests*"] +} diff --git a/strands-ts/package.json b/strands-ts/package.json new file mode 100644 index 0000000000..4060223f96 --- /dev/null +++ b/strands-ts/package.json @@ -0,0 +1,262 @@ +{ + "name": "@strands-agents/sdk", + "version": "0.0.1-development", + "description": "TypeScript SDK for Strands Agents framework", + "main": "dist/src/index.js", + "module": "dist/src/index.js", + "types": "dist/src/index.d.ts", + "type": "module", + "files": [ + "dist", + "README.md", + "LICENSE" + ], + "exports": { + ".": { + "types": "./dist/src/index.d.ts", + "default": "./dist/src/index.js" + }, + "./models/anthropic": { + "types": "./dist/src/models/anthropic.d.ts", + "default": "./dist/src/models/anthropic.js" + }, + "./models/openai": { + "types": "./dist/src/models/openai/index.d.ts", + "default": "./dist/src/models/openai/index.js" + }, + "./models/bedrock": { + "types": "./dist/src/models/bedrock.d.ts", + "default": "./dist/src/models/bedrock.js" + }, + "./models/google": { + "types": "./dist/src/models/google/index.d.ts", + "default": "./dist/src/models/google/index.js" + }, + "./models/vercel": { + "types": "./dist/src/models/vercel.d.ts", + "default": "./dist/src/models/vercel.js" + }, + "./multiagent": { + "types": "./dist/src/multiagent/index.d.ts", + "default": "./dist/src/multiagent/index.js" + }, + "./vended-tools/notebook": { + "types": "./dist/src/vended-tools/notebook/index.d.ts", + "default": "./dist/src/vended-tools/notebook/index.js" + }, + "./vended-tools/file-editor": { + "types": "./dist/src/vended-tools/file-editor/index.d.ts", + "default": "./dist/src/vended-tools/file-editor/index.js" + }, + "./vended-tools/http-request": { + "types": "./dist/src/vended-tools/http-request/index.d.ts", + "default": "./dist/src/vended-tools/http-request/index.js" + }, + "./vended-tools/bash": { + "types": "./dist/src/vended-tools/bash/index.d.ts", + "default": "./dist/src/vended-tools/bash/index.js" + }, + "./a2a": { + "types": "./dist/src/a2a/index.d.ts", + "default": "./dist/src/a2a/index.js" + }, + "./a2a/express": { + "types": "./dist/src/a2a/express-server.d.ts", + "default": "./dist/src/a2a/express-server.js" + }, + "./session/s3-storage": { + "types": "./dist/src/session/s3-storage.d.ts", + "default": "./dist/src/session/s3-storage.js" + }, + "./telemetry": { + "types": "./dist/src/telemetry/index.d.ts", + "default": "./dist/src/telemetry/index.js" + }, + "./vended-plugins/skills": { + "types": "./dist/src/vended-plugins/skills/index.d.ts", + "default": "./dist/src/vended-plugins/skills/index.js" + }, + "./vended-plugins/context-offloader": { + "types": "./dist/src/vended-plugins/context-offloader/index.d.ts", + "default": "./dist/src/vended-plugins/context-offloader/index.js" + }, + "./vended-interventions/hitl": { + "types": "./dist/src/vended-interventions/hitl/index.d.ts", + "default": "./dist/src/vended-interventions/hitl/index.js" + }, + "./vended-interventions/steering": { + "types": "./dist/src/vended-interventions/steering/index.d.ts", + "default": "./dist/src/vended-interventions/steering/index.js" + }, + "./vended-tools": { + "types": "./dist/src/vended-tools/index.d.ts", + "default": "./dist/src/vended-tools/index.js" + }, + "./vended-plugins": { + "types": "./dist/src/vended-plugins/index.d.ts", + "default": "./dist/src/vended-plugins/index.js" + } + }, + "scripts": { + "build": "tsc --project src/tsconfig.json", + "prepack": "npm run build && cp ../README.md . && cp ../LICENSE.APACHE LICENSE", + "postpack": "rm -f README.md LICENSE", + "check": "npm run lint && npm run format && npm run type-check && npm run check:browser-bundle && npm run test:coverage && npm run test:package", + "check:browser-bundle": "esbuild src/index.ts --bundle --platform=browser --format=esm --packages=external --outfile=/dev/null", + "clean": "rm -rf node_modules dist", + "lock:refresh": "rm -rf node_modules && npm install --ignore-scripts --os=linux --os=darwin --os=win32 --cpu=x64 --cpu=arm64 --cpu=wasm32", + "test": "vitest run --project unit-node", + "test:watch": "vitest --project unit-node", + "test:coverage": "vitest run --coverage --project unit-node", + "test:types": "vitest run --project types", + "test:integ": "vitest run --project integ-node", + "test:integ:browser": "vitest run --project integ-browser", + "test:integ:all": "vitest run --project integ-node --project integ-browser", + "test:browser": "vitest run --project unit-browser", + "test:browser:install": "npx playwright install --with-deps chromium", + "test:all": "vitest run --project unit-node --project unit-browser", + "test:all:coverage": "vitest run --coverage --project unit-node --project unit-browser", + "test:package": "cd test/packages/esm-module && npm install && node esm.js && cd ../cjs-module && npm install && node cjs.js", + "lint": "eslint src test/integ", + "lint:fix": "eslint src test/integ --fix", + "format": "prettier --write src test/integ", + "format:check": "prettier --check src test/integ", + "type-check": "tsc --noEmit --project src/tsconfig.json && tsc --noEmit --project test/integ/tsconfig.json", + "type-check:watch": "tsc --noEmit --watch" + }, + "keywords": [ + "agents", + "ai", + "typescript", + "sdk", + "strands" + ], + "author": "Strands Agents", + "license": "Apache-2.0", + "devDependencies": { + "@a2a-js/sdk": "^0.3.10", + "@ai-sdk/amazon-bedrock": "^4.0.77", + "@ai-sdk/openai": "^3.0.41", + "@ai-sdk/provider": "^3.0.0", + "@anthropic-ai/sdk": "^0.92.0", + "@aws-sdk/client-bedrock": "^3.943.0", + "@aws-sdk/client-s3": "^3.943.0", + "@aws-sdk/client-secrets-manager": "^3.943.0", + "@aws-sdk/client-sts": "^3.996.0", + "@aws-sdk/credential-providers": "^3.943.0", + "@aws/bedrock-token-generator": "^1.1.0", + "@eslint/js": "^9.39.4", + "@google/genai": "^1.40.0", + "@opentelemetry/api": "^1.9.0", + "@opentelemetry/exporter-metrics-otlp-http": "^0.214.0", + "@opentelemetry/exporter-trace-otlp-http": "^0.214.0", + "@opentelemetry/resources": "^2.6.1", + "@opentelemetry/sdk-metrics": "^2.6.1", + "@opentelemetry/sdk-trace-base": "^2.6.1", + "@opentelemetry/sdk-trace-node": "^2.6.1", + "@smithy/types": "^4.0.0", + "@types/express": "^5.0.6", + "@types/node": "^25.6.0", + "@types/uuid": "^11.0.0", + "@typescript-eslint/eslint-plugin": "^8.48.1", + "@typescript-eslint/parser": "^8.0.0", + "@vitest/browser": "^4.0.15", + "@vitest/browser-playwright": "^4.0.15", + "@vitest/coverage-v8": "^4.0.15", + "eslint": "^10.2.0", + "eslint-plugin-tsdoc": "^0.5.0", + "express": "^5.2.1", + "openai": "^6.7.0", + "playwright": "^1.60.0", + "tsx": "^4.21.0", + "typescript": "^6.0.2", + "vitest": "^4.0.8" + }, + "engines": { + "node": ">=20.0.0" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/strands-agents/sdk-typescript.git" + }, + "bugs": { + "url": "https://github.com/strands-agents/sdk-typescript/issues" + }, + "homepage": "https://github.com/strands-agents/sdk-typescript#readme", + "dependencies": { + "@aws-sdk/client-bedrock-runtime": "^3.1037.0", + "@types/json-schema": "^7.0.15", + "uuid": "^14.0.0", + "yaml": "^2.8.3" + }, + "peerDependencies": { + "@a2a-js/sdk": "^0.3.10", + "@ai-sdk/provider": "^3.0.0", + "@anthropic-ai/sdk": "^0.92.0", + "@aws-sdk/client-s3": "^3.943.0", + "@aws/bedrock-token-generator": "^1.1.0", + "@google/genai": "^1.40.0", + "@modelcontextprotocol/sdk": "^1.25.2", + "@opentelemetry/api": "^1.9.0", + "@opentelemetry/exporter-metrics-otlp-http": "^0.214.0", + "@opentelemetry/exporter-trace-otlp-http": "^0.214.0", + "@opentelemetry/resources": "^2.6.1", + "@opentelemetry/sdk-metrics": "^2.6.1", + "@opentelemetry/sdk-trace-base": "^2.6.1", + "@opentelemetry/sdk-trace-node": "^2.6.1", + "@smithy/types": "^4.0.0", + "express": "^5.1.0", + "openai": "^6.7.0", + "zod": "^4.1.12" + }, + "peerDependenciesMeta": { + "@a2a-js/sdk": { + "optional": true + }, + "@ai-sdk/provider": { + "optional": true + }, + "@anthropic-ai/sdk": { + "optional": true + }, + "@aws/bedrock-token-generator": { + "optional": true + }, + "@smithy/types": { + "optional": true + }, + "express": { + "optional": true + }, + "@aws-sdk/client-s3": { + "optional": true + }, + "@google/genai": { + "optional": true + }, + "openai": { + "optional": true + }, + "@opentelemetry/sdk-trace-base": { + "optional": true + }, + "@opentelemetry/sdk-trace-node": { + "optional": true + }, + "@opentelemetry/sdk-metrics": { + "optional": true + }, + "@opentelemetry/resources": { + "optional": true + }, + "@opentelemetry/exporter-trace-otlp-http": { + "optional": true + }, + "@opentelemetry/exporter-metrics-otlp-http": { + "optional": true + } + }, + "overrides": { + "fast-xml-parser": ">=5.3.6" + } +} diff --git a/strands-ts/src/__fixtures__/agent-helpers.ts b/strands-ts/src/__fixtures__/agent-helpers.ts new file mode 100644 index 0000000000..53ed7eba0b --- /dev/null +++ b/strands-ts/src/__fixtures__/agent-helpers.ts @@ -0,0 +1,238 @@ +/** + * Test fixtures and helpers for Agent testing. + * This module provides utilities for testing Agent-related implementations. + */ + +import { expect } from 'vitest' +import type { Agent } from '../agent/agent.js' +import type { AgentResult, InvokableAgent, InvokeArgs, InvokeOptions } from '../types/agent.js' +import type { StopReason } from '../types/messages.js' +import { Message, TextBlock } from '../types/messages.js' +import type { Role } from '../types/messages.js' +import { StateStore } from '../state-store.js' +import type { JSONValue } from '../types/json.js' +import { ToolRegistry } from '../registry/tool-registry.js' +import type { HookableEvent, StreamEvent } from '../hooks/events.js' +import type { HookableEventConstructor, HookCallback } from '../hooks/types.js' +import { expectLoopMetrics, type LoopMetricsMatcher } from './metrics-helpers.js' + +/** + * A hook registration captured by the mock agent's addHook. + */ +export type TrackedHook = { + eventType: HookableEventConstructor + callback: HookCallback +} + +/** + * Data for creating a mock Agent. + */ +export interface MockAgentData { + /** + * Messages for the agent. + */ + messages?: Message[] + /** + * Initial state for the agent. + */ + appState?: Record + /** + * Optional tool registry for the agent. + */ + toolRegistry?: ToolRegistry + /** + * Additional properties to spread onto the mock agent. + */ + extra?: Partial +} + +/** + * A mock Agent with a `trackedHooks` array populated by `addHook` calls. + */ +export type MockAgent = Agent & { trackedHooks: TrackedHook[] } + +/** + * Helper to create a mock Agent for testing. + * Provides minimal Agent interface with messages, appState, and tool registry. + * `addHook` captures registrations into `trackedHooks` for test inspection. + * + * @param data - Optional mock agent data + * @returns Mock Agent with trackedHooks + */ +export function createMockAgent(data?: MockAgentData): MockAgent { + const trackedHooks: TrackedHook[] = [] + return { + messages: data?.messages ?? [], + appState: new StateStore(data?.appState ?? {}), + modelState: new StateStore(), + toolRegistry: data?.toolRegistry ?? new ToolRegistry(), + cancelSignal: new AbortController().signal, + addHook: (eventType: HookableEventConstructor, callback: HookCallback) => { + trackedHooks.push({ + eventType: eventType as HookableEventConstructor, + callback: callback as HookCallback, + }) + return () => {} + }, + ...data?.extra, + trackedHooks, + } as unknown as MockAgent +} + +/** + * Creates a Message with the given role containing a single TextBlock. + * + * @param role - The message role + * @param text - The text content + * @returns A Message with the specified role + */ +export function textMessage(role: Role, text: string): Message { + return new Message({ role, content: [new TextBlock(text)] }) +} + +/** + * Finds the tracked hook for the given event type and invokes it with the provided event. + * Throws if no hook is registered for that event type. + * + * @param agent - The mock agent with tracked hooks + * @param event - The event instance to dispatch + */ +export async function invokeTrackedHook(agent: MockAgent, event: T): Promise { + const hook = agent.trackedHooks.find((h) => h.eventType === event.constructor) + if (!hook) { + throw new Error(`No hook registered for event type: ${event.constructor.name}`) + } + await hook.callback(event) +} + +/** + * Options for building an AgentResult matcher. + */ +export interface AgentResultMatcher extends Omit { + /** + * Expected stop reason from the final model response. + */ + stopReason: StopReason + + /** + * Expected text content in the last assistant message's TextBlock. + * When provided, asserts exact text match in a TextBlock with role 'assistant'. + * When omitted, only validates lastMessage exists with role 'assistant'. + */ + messageText?: string + + /** + * Expected number of agent loop cycles. + */ + cycleCount: number + + /** + * Expected number of traces. When provided, asserts exact array length. + * When omitted, asserts traces array exists with at least one element. + */ + traceCount?: number + + /** + * Expected `invocationState` on the result. When provided, the full object + * must match exactly — extra keys fail. When omitted, only asserts + * `invocationState` is present (any object). + */ + invocationState?: Record +} + +/** + * Creates an asymmetric matcher that validates AgentResult structure and values. + * Reduces nesting in test assertions by providing a clean, readable matcher. + * + * @param options - Expected result values + * @returns An asymmetric matcher suitable for use in expect().toEqual() + * + * @example + * ```typescript + * expect(result).toEqual(expectAgentResult({ + * stopReason: 'endTurn', + * messageText: 'Hello', + * cycleCount: 1, + * })) + * ``` + */ +export function expectAgentResult(options: AgentResultMatcher): AgentResult { + const { stopReason, messageText, cycleCount, traceCount, toolNames, usage, invocationState } = options + + const expectedLastMessage = messageText + ? expect.objectContaining({ + role: 'assistant', + content: expect.arrayContaining([expect.objectContaining({ type: 'textBlock', text: messageText })]), + }) + : expect.objectContaining({ role: 'assistant' }) + + const expectedTraces = + traceCount !== undefined + ? expect.objectContaining({ length: traceCount }) + : expect.arrayContaining([expect.objectContaining({ name: expect.any(String) })]) + + // Build metrics matcher options, only including defined properties + const metricsOptions: LoopMetricsMatcher = { cycleCount } + if (toolNames !== undefined) { + metricsOptions.toolNames = toolNames + } + if (usage !== undefined) { + metricsOptions.usage = usage + } + + return expect.objectContaining({ + type: 'agentResult', + stopReason, + lastMessage: expectedLastMessage, + metrics: expectLoopMetrics(metricsOptions), + traces: expectedTraces, + invocationState: invocationState ?? expect.any(Object), + }) as AgentResult +} + +/** + * Creates a minimal InvokableAgent that sleeps for `delayMs` before resolving, + * aborting the sleep early when the invocation's `cancelSignal` fires. Used to + * exercise timeout and cancellation behavior deterministically without spinning + * up a full Agent. + * + * @param id - The agent id + * @param delayMs - How long the agent should sleep before returning + * @param structuredOutput - Optional structured output (e.g. a swarm handoff). When present, + * its `message` field is used as the assistant text. + */ +export function createCancellableAgent( + id: string, + delayMs: number, + structuredOutput: { agentId?: string; message: string } = { message: 'done' } +): InvokableAgent { + const sleep = (signal?: AbortSignal): Promise => + new Promise((resolve, reject) => { + const timer = setTimeout(resolve, delayMs) + if (signal) { + const onAbort = (): void => { + clearTimeout(timer) + reject(new Error('cancelled')) + } + if (signal.aborted) onAbort() + else signal.addEventListener('abort', onAbort, { once: true }) + } + }) + + return { + id, + description: `Agent ${id}`, + async invoke(_args: InvokeArgs, options?: InvokeOptions): Promise { + await sleep(options?.cancelSignal) + return { + stopReason: 'endTurn', + lastMessage: { role: 'assistant', content: [new TextBlock(structuredOutput.message)] }, + structuredOutput, + } as AgentResult + }, + // eslint-disable-next-line require-yield + async *stream(args: InvokeArgs, options?: InvokeOptions): AsyncGenerator { + return await this.invoke(args, options) + }, + } +} diff --git a/strands-ts/src/__fixtures__/environment.ts b/strands-ts/src/__fixtures__/environment.ts new file mode 100644 index 0000000000..6ce65b86ff --- /dev/null +++ b/strands-ts/src/__fixtures__/environment.ts @@ -0,0 +1,14 @@ +/** + * Environment detection utilities for tests + */ + +/** + * Detects if the current environment is Node.js + */ +export const isNode = + typeof process !== 'undefined' && typeof process.versions !== 'undefined' && !!process.versions.node + +/** + * Detects if the current environment is a browser + */ +export const isBrowser = typeof window !== 'undefined' diff --git a/strands-ts/src/__fixtures__/metrics-helpers.ts b/strands-ts/src/__fixtures__/metrics-helpers.ts new file mode 100644 index 0000000000..bdb84a5f2e --- /dev/null +++ b/strands-ts/src/__fixtures__/metrics-helpers.ts @@ -0,0 +1,90 @@ +/** + * Test helpers for asserting on AgentMetrics in agent tests. + */ + +import { expect } from 'vitest' +import type { Usage } from '../models/streaming.js' +import { AgentMetrics } from '../telemetry/meter.js' + +/** + * Options for building an AgentMetrics matcher. + */ +export interface LoopMetricsMatcher { + /** + * Expected number of agent loop cycles. + */ + cycleCount: number + + /** + * Expected tool names that were invoked. + */ + toolNames?: string[] + + /** + * Expected accumulated token usage. When provided, asserts exact values. + * When omitted, asserts the shape with expect.any(Number). + */ + usage?: Usage +} + +/** + * Creates an asymmetric matcher that validates AgentMetrics structure and values. + * + * @param options - Expected metric values + * @returns An asymmetric matcher suitable for use in expect().toEqual() + */ +export function expectLoopMetrics(options: LoopMetricsMatcher): AgentMetrics { + const { cycleCount, toolNames = [], usage } = options + + const expectedToolMetrics: Record = {} + for (const name of toolNames) { + expectedToolMetrics[name] = { + callCount: expect.any(Number), + successCount: expect.any(Number), + errorCount: expect.any(Number), + totalTime: expect.any(Number), + } + } + + const expectedUsage = + usage ?? + expect.objectContaining({ + inputTokens: expect.any(Number), + outputTokens: expect.any(Number), + totalTokens: expect.any(Number), + }) + + return expect.objectContaining({ + cycleCount, + toolMetrics: toolNames.length > 0 ? expect.objectContaining(expectedToolMetrics) : {}, + accumulatedUsage: expectedUsage, + accumulatedMetrics: { latencyMs: expect.any(Number) }, + }) as AgentMetrics +} + +/** + * Finds the latest data point value for a named metric from OTEL ResourceMetrics. + * + * Flattens the ResourceMetrics → ScopeMetrics → MetricData hierarchy and + * returns the value of the last data point for the matching metric name. + * For counters this is a number; for histograms it is an object with + * sum, count, min, max, etc. + * + * @param resourceMetrics - Array of ResourceMetrics from an InMemoryMetricExporter + * @param metricName - The metric descriptor name to search for + * @returns The value of the last data point, or undefined if not found + */ +export function findMetricValue( + resourceMetrics: { + scopeMetrics: { metrics: { descriptor: { name: string }; dataPoints: { value: unknown }[] }[] }[] + }[], + metricName: string +): unknown { + const dp = resourceMetrics + .flatMap((rm) => rm.scopeMetrics) + .flatMap((sm) => sm.metrics) + .filter((m) => m.descriptor.name === metricName) + .flatMap((m) => m.dataPoints) + .at(-1) + return dp?.value +} diff --git a/strands-ts/src/__fixtures__/mock-message-model.ts b/strands-ts/src/__fixtures__/mock-message-model.ts new file mode 100644 index 0000000000..5962ec182a --- /dev/null +++ b/strands-ts/src/__fixtures__/mock-message-model.ts @@ -0,0 +1,303 @@ +/** + * Test message model provider for simplified agent testing. + * This module provides a content-focused test model that generates appropriate + * ModelStreamEvents from ContentBlock objects, eliminating the need to manually + * construct events in tests. + */ + +import { Model } from '../models/model.js' +import type { Message, StopReason } from '../types/messages.js' +import type { ModelStreamEvent, Usage } from '../models/streaming.js' +import type { BaseModelConfig, StreamOptions } from '../models/model.js' +import type { PlainContentBlock } from './slim-types.js' + +/** + * Input type for addTurn - accepts plain objects or class instances. + */ +type ContentBlockInput = PlainContentBlock | PlainContentBlock[] | Error + +/** + * Represents a single turn in the test sequence. + * Can be either content blocks with stopReason, or an Error to throw. + */ +type Turn = + | { type: 'content'; content: PlainContentBlock[]; stopReason: StopReason; usage?: Usage } + | { type: 'error'; error: Error } + +/** + * Test model provider that operates at the content block level. + * Simplifies agent loop tests by allowing specification of content blocks + * instead of manually yielding individual ModelStreamEvents. + */ +export class MockMessageModel extends Model { + private _turns: Turn[] + private _currentTurnIndex: number + private _config: BaseModelConfig + + /** + * Creates a new MockMessageModel. + */ + constructor() { + super() + this._config = { modelId: 'test-model' } + this._currentTurnIndex = 0 + this._turns = [] + } + + /** + * The number of turns that have been invoked thus far. + */ + get callCount(): number { + return this._currentTurnIndex + } + + /** + * Adds a turn to the test sequence. + * Returns this for method chaining. + * + * @param turn - ContentBlock, ContentBlock[], or Error to add + * @param options - Optional stop reason and token usage + * @returns This provider for chaining + * + * @example + * ```typescript + * provider + * .addTurn({ type: 'textBlock', text: 'Hello' }) // Single block + * .addTurn([{ type: 'toolUseBlock', ... }]) // Array of blocks + * .addTurn({ type: 'textBlock', text: 'Done' }, { stopReason: 'maxTokens' }) // Explicit stopReason + * .addTurn({ type: 'textBlock', text: 'Hi' }, { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } }) + * .addTurn(new Error('Failed')) // Error turn + * ``` + */ + addTurn(turn: ContentBlockInput, options?: { stopReason?: StopReason; usage?: Usage }): this { + this._turns.push(this._createTurn(turn, options?.stopReason, options?.usage)) + return this + } + + /** + * Updates the model configuration. + * + * @param modelConfig - Configuration to merge with existing config + */ + updateConfig(modelConfig: BaseModelConfig): void { + this._config = { ...this._config, ...modelConfig } + } + + /** + * Retrieves the current model configuration. + * + * @returns Current configuration object + */ + getConfig(): BaseModelConfig { + return this._config + } + + /** + * Streams a conversation with the model. + * Generates appropriate ModelStreamEvents from the content blocks. + * + * Single-turn behavior: Reuses the same turn indefinitely + * Multi-turn behavior: Advances through turns and throws when exhausted + * + * @param _messages - Conversation messages (ignored by test provider) + * @param _options - Streaming options (ignored by test provider) + * @returns Async iterable of ModelStreamEvents + */ + async *stream(_messages: Message[], _options?: StreamOptions): AsyncGenerator { + // Determine which turn index to use + // For single turn, always use 0. For multiple turns, use current index + const turnIndex = this._turns.length === 1 ? 0 : this._currentTurnIndex + + // Advance turn index immediately for multi-turn scenarios + // This ensures that the next call to stream() will use the next turn + if (this._turns.length > 1) { + this._currentTurnIndex++ + } + + // Check if we've exhausted all turns (after potential increment) + if (turnIndex >= this._turns.length) { + throw new Error('All turns have been consumed') + } + + // Get the current turn + const turn = this._turns[turnIndex]! + + // Handle error turns + if (turn.type === 'error') { + throw turn.error + } + + // Generate events for content turn + yield* this._generateEventsForContent(turn.content, turn.stopReason, turn.usage) + } + + /** + * Generates appropriate ModelStreamEvents for content blocks. + * All messages have role 'assistant' since this is for testing model responses. + */ + private async *_generateEventsForContent( + content: PlainContentBlock[], + stopReason: StopReason, + usage?: Usage + ): AsyncGenerator { + // Yield message start event (always assistant role) + yield { type: 'modelMessageStartEvent', role: 'assistant' } + + // Yield events for each content block + for (let i = 0; i < content.length; i++) { + const block = content[i]! + yield* this._generateEventsForBlock(block) + } + + // Yield message stop event + yield { type: 'modelMessageStopEvent', stopReason } + + // Yield metadata event with token usage when provided + if (usage) { + yield { type: 'modelMetadataEvent', usage } + } + } + + /** + * Creates a Turn object from ContentBlock(s) or Error. + */ + private _createTurn(turn: ContentBlockInput, explicitStopReason?: StopReason, usage?: Usage): Turn { + if (turn instanceof Error) { + return { type: 'error', error: turn } + } + + // Normalize to array + const content = Array.isArray(turn) ? turn : [turn] + + return { + type: 'content', + content, + stopReason: explicitStopReason ?? this._deriveStopReason(content), + ...(usage !== undefined && { usage }), + } + } + + /** + * Auto-derives stopReason from content blocks. + * Returns 'toolUse' if content contains any ToolUseBlock, otherwise 'endTurn'. + */ + private _deriveStopReason(content: PlainContentBlock[]): StopReason { + const hasToolUse = content.some((block) => block.type === 'toolUseBlock') + return hasToolUse ? 'toolUse' : 'endTurn' + } + + /** + * Generates appropriate ModelStreamEvents for a message. + */ + private async *_generateEventsForMessage(message: Message, stopReason: StopReason): AsyncGenerator { + // Yield message start event + yield { type: 'modelMessageStartEvent', role: message.role } + + // Yield events for each content block + for (let i = 0; i < message.content.length; i++) { + const block = message.content[i]! + yield* this._generateEventsForBlock(block) + } + + // Yield message stop event + yield { type: 'modelMessageStopEvent', stopReason } + } + + /** + * Generates appropriate ModelStreamEvents for a content block. + */ + private async *_generateEventsForBlock(block: PlainContentBlock): AsyncGenerator { + switch (block.type) { + case 'textBlock': + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: block.text }, + } + yield { type: 'modelContentBlockStopEvent' } + break + + case 'toolUseBlock': + yield { + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: block.name, toolUseId: block.toolUseId }, + } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: JSON.stringify(block.input) }, + } + yield { type: 'modelContentBlockStopEvent' } + break + + case 'reasoningBlock': { + yield { type: 'modelContentBlockStartEvent' } + // Build delta object with only defined properties + const delta: { + type: 'reasoningContentDelta' + text?: string + signature?: string + redactedContent?: Uint8Array + } = { + type: 'reasoningContentDelta', + } + if (block.text !== undefined) { + delta.text = block.text + } + if (block.signature !== undefined) { + delta.signature = block.signature + } + if (block.redactedContent !== undefined) { + delta.redactedContent = block.redactedContent + } + yield { + type: 'modelContentBlockDeltaEvent', + delta, + } + yield { type: 'modelContentBlockStopEvent' } + break + } + + case 'cachePointBlock': + // CachePointBlock doesn't generate delta events + yield { type: 'modelContentBlockStartEvent' } + yield { type: 'modelContentBlockStopEvent' } + break + + case 'toolResultBlock': + // ToolResultBlock appears in user messages and doesn't generate model events + // This shouldn't normally be in assistant messages, but we'll handle it gracefully + break + + case 'guardContentBlock': + // GuardContentBlock is handled by guardrails and doesn't generate model events + // This is typically used in system prompts or message content for guardrail evaluation + break + + case 'citationsBlock': + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'citationsDelta', + citations: block.citations, + content: block.content, + }, + } + yield { type: 'modelContentBlockStopEvent' } + break + + case 'imageBlock': + case 'videoBlock': + case 'documentBlock': + case 'jsonBlock': + // These blocks don't generate events in mock - just skip them + break + + default: { + // Exhaustive check + const _exhaustive: never = block + throw new Error(`Unknown content block type: ${(_exhaustive as PlainContentBlock).type}`) + } + } + } +} diff --git a/strands-ts/src/__fixtures__/mock-meter.ts b/strands-ts/src/__fixtures__/mock-meter.ts new file mode 100644 index 0000000000..f4868fc9c6 --- /dev/null +++ b/strands-ts/src/__fixtures__/mock-meter.ts @@ -0,0 +1,64 @@ +/** + * Mock OpenTelemetry Meter for testing metric instrument emission. + * Records all counter and histogram data points for assertion. + */ + +import type { Attributes } from '@opentelemetry/api' + +export interface MockDataPoint { + value: number + attributes: Attributes | undefined +} + +export class MockCounter { + readonly dataPoints: MockDataPoint[] = [] + + add(value: number, attributes?: Attributes): void { + this.dataPoints.push({ value, attributes }) + } + + get sum(): number { + return this.dataPoints.reduce((acc, dp) => acc + dp.value, 0) + } +} + +export class MockHistogram { + readonly dataPoints: MockDataPoint[] = [] + + record(value: number, attributes?: Attributes): void { + this.dataPoints.push({ value, attributes }) + } + + get sum(): number { + return this.dataPoints.reduce((acc, dp) => acc + dp.value, 0) + } +} + +/** + * Mock OTEL Meter that tracks created instruments by name. + * Cast to `Meter` when passing to `vi.spyOn(otelMetrics, 'getMeter')`. + */ +export class MockMeter { + private readonly _counters = new Map() + private readonly _histograms = new Map() + + createCounter(name: string): MockCounter { + const counter = new MockCounter() + this._counters.set(name, counter) + return counter + } + + createHistogram(name: string): MockHistogram { + const histogram = new MockHistogram() + this._histograms.set(name, histogram) + return histogram + } + + getCounter(name: string): MockCounter | undefined { + return this._counters.get(name) + } + + getHistogram(name: string): MockHistogram | undefined { + return this._histograms.get(name) + } +} diff --git a/strands-ts/src/__fixtures__/mock-plugin.ts b/strands-ts/src/__fixtures__/mock-plugin.ts new file mode 100644 index 0000000000..4b36cc5d95 --- /dev/null +++ b/strands-ts/src/__fixtures__/mock-plugin.ts @@ -0,0 +1,56 @@ +import type { HookableEvent } from '../hooks/index.js' +import type { Plugin } from '../plugins/plugin.js' +import type { LocalAgent } from '../types/agent.js' +import { + InitializedEvent, + BeforeInvocationEvent, + AfterInvocationEvent, + MessageAddedEvent, + BeforeToolsEvent, + AfterToolsEvent, + BeforeToolCallEvent, + AfterToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, + ToolResultEvent, + ToolStreamUpdateEvent, +} from '../hooks/index.js' +import type { HookableEventConstructor } from '../hooks/types.js' + +/** + * Mock plugin that records all hookable event invocations for testing. + */ +export class MockPlugin implements Plugin { + invocations: HookableEvent[] = [] + + get name(): string { + return 'mock-plugin' + } + + initAgent(agent: LocalAgent): void { + const eventTypes: HookableEventConstructor[] = [ + InitializedEvent, + BeforeInvocationEvent, + AfterInvocationEvent, + MessageAddedEvent, + BeforeToolsEvent, + AfterToolsEvent, + BeforeToolCallEvent, + AfterToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, + ToolResultEvent, + ToolStreamUpdateEvent, + ] + + for (const eventType of eventTypes) { + agent.addHook(eventType, (e) => { + this.invocations.push(e) + }) + } + } + + reset(): void { + this.invocations = [] + } +} diff --git a/strands-ts/src/__fixtures__/mock-span.ts b/strands-ts/src/__fixtures__/mock-span.ts new file mode 100644 index 0000000000..2a0f454e75 --- /dev/null +++ b/strands-ts/src/__fixtures__/mock-span.ts @@ -0,0 +1,121 @@ +/** + * Mock OpenTelemetry Span for testing tracer functionality. + * Implements the full Span interface and records all calls for assertion. + */ + +import type { + Span, + SpanContext, + SpanStatus, + SpanAttributes, + SpanAttributeValue, + TimeInput, + Exception, + Link, +} from '@opentelemetry/api' + +/** + * Concrete mock implementing the Span interface. + * Chainable methods return `this` to satisfy the `Span` contract. + */ +export class MockSpan implements Span { + readonly calls = { + setAttribute: [] as Array<{ key: string; value: SpanAttributeValue }>, + setAttributes: [] as Array<{ attributes: SpanAttributes }>, + addEvent: [] as Array<{ + name: string + attributes: SpanAttributes | TimeInput | undefined + startTime: TimeInput | undefined + }>, + setStatus: [] as Array<{ status: SpanStatus }>, + updateName: [] as Array<{ name: string }>, + end: [] as Array<{ endTime: TimeInput | undefined }>, + recordException: [] as Array<{ exception: Exception; time: TimeInput | undefined }>, + } + + /** @returns A fixed span context for test assertions. */ + spanContext(): SpanContext { + return { traceId: 'trace-1', spanId: 'span-1', traceFlags: 1 } + } + + /** Records a single attribute. */ + setAttribute(key: string, value: SpanAttributeValue): this { + this.calls.setAttribute.push({ key, value }) + return this + } + + /** Records a batch of attributes. */ + setAttributes(attributes: SpanAttributes): this { + this.calls.setAttributes.push({ attributes }) + for (const [key, value] of Object.entries(attributes)) { + if (value !== undefined) this.setAttribute(key, value) + } + return this + } + + /** Records a span event with optional attributes. */ + addEvent(name: string, attributesOrStartTime?: SpanAttributes | TimeInput, startTime?: TimeInput): this { + this.calls.addEvent.push({ name, attributes: attributesOrStartTime, startTime }) + return this + } + + /** No-op link addition. */ + addLink(_link: Link): this { + return this + } + + /** No-op batch link addition. */ + addLinks(_links: Link[]): this { + return this + } + + /** Records a status change. */ + setStatus(status: SpanStatus): this { + this.calls.setStatus.push({ status }) + return this + } + + /** Records a name update. */ + updateName(name: string): this { + this.calls.updateName.push({ name }) + return this + } + + /** Records span end. */ + end(endTime?: TimeInput): void { + this.calls.end.push({ endTime }) + } + + /** Always returns true for mock spans. */ + isRecording(): boolean { + return true + } + + /** Records an exception. */ + recordException(exception: Exception, time?: TimeInput): void { + this.calls.recordException.push({ exception, time }) + } + + /** + * Get the value of a specific attribute set via setAttribute. + */ + getAttributeValue(key: string): SpanAttributeValue | undefined { + const entry = this.calls.setAttribute.find((c) => c.key === key) + return entry?.value + } + + /** + * Get all events with a given name. + */ + getEvents(name: string): Array<{ name: string; attributes: SpanAttributes | TimeInput | undefined }> { + return this.calls.addEvent.filter((c) => c.name === name) + } +} + +/** + * Extract a string attribute from a mock span event's attributes. + */ +export function eventAttr(event: { attributes: SpanAttributes | TimeInput | undefined }, key: string): string { + const attrs = event.attributes as Record + return attrs[key]! +} diff --git a/strands-ts/src/__fixtures__/mock-storage-provider.ts b/strands-ts/src/__fixtures__/mock-storage-provider.ts new file mode 100644 index 0000000000..b47a57663d --- /dev/null +++ b/strands-ts/src/__fixtures__/mock-storage-provider.ts @@ -0,0 +1,152 @@ +import type { Scope, Snapshot, SnapshotManifest } from '../session/types.js' +import type { SnapshotStorage, SnapshotLocation } from '../session/index.js' + +export function createTestSnapshot(overrides: Partial = {}): Snapshot { + return { + schemaVersion: '1.0', + scope: 'agent', + createdAt: '2024-01-01T00:00:00.000Z', + data: { + messages: [], + state: { testKey: 'testValue' }, + systemPrompt: 'You are a test assistant', + }, + appData: {}, + ...overrides, + } +} + +export function createTestManifest(overrides: Partial = {}): SnapshotManifest { + return { + schemaVersion: '1.0', + updatedAt: '2024-01-01T00:00:00.000Z', + ...overrides, + } +} + +export function createTestScope(kind: 'agent' | 'multiAgent' = 'agent'): Scope { + return kind +} + +export function createTestSnapshots(count: number, baseSnapshot?: Partial): Snapshot[] { + return Array.from({ length: count }, (_, i) => + createTestSnapshot({ + ...baseSnapshot, + createdAt: new Date(2024, 0, 1, 0, i).toISOString(), + }) + ) +} + +/** + * Mock storage implementation for testing that stores data in memory + */ +export class MockSnapshotStorage implements SnapshotStorage { + private snapshots = new Map() + private manifests = new Map() + public shouldThrowErrors = false + + async saveSnapshot(params: { + location: SnapshotLocation + snapshotId: string + isLatest: boolean + snapshot: Snapshot + }): Promise { + if (this.shouldThrowErrors) throw new Error('Mock save error') + + const { location, snapshotId, isLatest, snapshot } = params + const key = this.getKey(location, snapshotId) + this.snapshots.set(key, snapshot) + + if (isLatest) { + this.snapshots.set(this.getKey(location, 'latest'), snapshot) + } + } + + async loadSnapshot(params: { location: SnapshotLocation; snapshotId?: string }): Promise { + if (this.shouldThrowErrors) throw new Error('Mock load error') + + if (params.snapshotId === undefined) { + return this.snapshots.get(this.getKey(params.location, 'latest')) ?? null + } + return this.snapshots.get(this.getKey(params.location, params.snapshotId)) ?? null + } + + async listSnapshotIds(params: { + location: SnapshotLocation + limit?: number + startAfter?: string + }): Promise { + if (this.shouldThrowErrors) throw new Error('Mock list error') + + const prefix = `${params.location.sessionId}::${params.location.scope}::${params.location.scopeId}::` + let ids: string[] = [] + + for (const [key] of this.snapshots) { + if (key.startsWith(prefix) && !key.endsWith('::latest')) { + ids.push(key.slice(prefix.length)) + } + } + + ids = ids.sort() + if (params.startAfter) { + ids = ids.filter((id) => id > params.startAfter!) + } + if (params.limit !== undefined) { + ids = ids.slice(0, params.limit) + } + return ids + } + + async deleteSession(params: { sessionId: string }): Promise { + if (this.shouldThrowErrors) throw new Error('Mock delete error') + + for (const key of this.snapshots.keys()) { + if (key.startsWith(`${params.sessionId}::`)) this.snapshots.delete(key) + } + for (const key of this.manifests.keys()) { + if (key.startsWith(`${params.sessionId}::`)) this.manifests.delete(key) + } + } + + async loadManifest(params: { location: SnapshotLocation }): Promise { + if (this.shouldThrowErrors) throw new Error('Mock manifest load error') + + const { sessionId } = params.location + if (!sessionId) { + throw new Error('Invalid sessionId: cannot be empty or undefined') + } + + const key = this.getManifestKey(params.location) + return ( + this.manifests.get(key) ?? { + schemaVersion: '1.0', + updatedAt: new Date().toISOString(), + } + ) + } + + async saveManifest(params: { location: SnapshotLocation; manifest: SnapshotManifest }): Promise { + if (this.shouldThrowErrors) throw new Error('Mock manifest save error') + + const { sessionId } = params.location + if (!sessionId) { + throw new Error('Invalid sessionId: cannot be empty or undefined') + } + + this.manifests.set(this.getManifestKey(params.location), params.manifest) + } + + private getKey(location: SnapshotLocation, snapshotId: string): string { + if (!location.sessionId) { + throw new Error('Invalid sessionId: cannot be empty or undefined') + } + return `${location.sessionId}::${location.scope}::${location.scopeId}::${snapshotId}` + } + + private getManifestKey(location: SnapshotLocation): string { + if (!location.sessionId) { + throw new Error('Invalid sessionId: cannot be empty or undefined') + } + return `${location.sessionId}::${location.scope}::${location.scopeId}::manifest` + } +} diff --git a/strands-ts/src/__fixtures__/model-test-helpers.ts b/strands-ts/src/__fixtures__/model-test-helpers.ts new file mode 100644 index 0000000000..adfda0e06b --- /dev/null +++ b/strands-ts/src/__fixtures__/model-test-helpers.ts @@ -0,0 +1,99 @@ +/** + * Test fixtures and helpers for Model testing. + * This module provides utilities for testing Model implementations without + * requiring actual API clients. + */ + +import { Model } from '../models/model.js' +import type { Message } from '../types/messages.js' +import type { ModelStreamEvent } from '../models/streaming.js' +import type { BaseModelConfig, StreamOptions } from '../models/model.js' + +/** + * Test model provider that returns a predefined stream of events. + * Useful for testing Model.streamAggregated() and other Model functionality + * without requiring actual API calls. + * + * @example + * ```typescript + * const provider = new TestModelProvider(async function* () { + * yield { type: 'modelMessageStartEvent', role: 'assistant' } + * yield { type: 'modelContentBlockStartEvent' } + * yield { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'Hello' } } + * yield { type: 'modelContentBlockStopEvent' } + * yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + * }) + * + * const message = await collectAggregated(provider.streamAggregated(messages)) + * ``` + */ +export class TestModelProvider extends Model { + private eventGenerator: (() => AsyncGenerator) | undefined + private config: BaseModelConfig = { modelId: 'test-model' } + + constructor(eventGenerator?: () => AsyncGenerator) { + super() + this.eventGenerator = eventGenerator + } + + setEventGenerator(eventGenerator: () => AsyncGenerator): void { + this.eventGenerator = eventGenerator + } + + updateConfig(modelConfig: BaseModelConfig): void { + this.config = { ...this.config, ...modelConfig } + } + + getConfig(): BaseModelConfig { + return this.config + } + + async *stream(_messages: Message[], _options?: StreamOptions): AsyncGenerator { + if (!this.eventGenerator) { + throw new Error('Event generator not set') + } + yield* this.eventGenerator() + } +} + +/** + * Helper function to collect events and result from an async generator. + * Properly handles AsyncGenerator where the final value is returned + * rather than yielded. + * + * @param generator - An async generator that yields items and returns a final result + * @returns Object with items array (yielded values) and result (return value) + */ +export async function collectGenerator( + generator: AsyncGenerator +): Promise<{ items: E[]; result: R }> { + const items: E[] = [] + let done = false + let result: R | undefined + + while (!done) { + const { value, done: isDone } = await generator.next() + done = isDone ?? false + if (!done) { + items.push(value as E) + } else { + result = value as R + } + } + + return { items, result: result! } +} + +/** + * Helper function to collect all items from an async iterator. + * + * @param stream - An async iterable that yields items + * @returns Array of all yielded items + */ +export async function collectIterator(stream: AsyncIterable): Promise { + const items: T[] = [] + for await (const item of stream) { + items.push(item) + } + return items +} diff --git a/strands-ts/src/__fixtures__/slim-types.ts b/strands-ts/src/__fixtures__/slim-types.ts new file mode 100644 index 0000000000..e4f6684ef0 --- /dev/null +++ b/strands-ts/src/__fixtures__/slim-types.ts @@ -0,0 +1,75 @@ +/** + * Utility types for testing that strip methods from classes. + * Allows tests to use plain objects without needing to construct class instances. + */ + +import type { + Message, + ToolResultBlock, + TextBlock, + ToolUseBlock, + ReasoningBlock, + CachePointBlock, + GuardContentBlock, + JsonBlock, +} from '../types/messages.js' +import type { ImageBlock, VideoBlock, DocumentBlock } from '../types/media.js' +import type { CitationsBlock } from '../types/citations.js' + +/** + * Strips the toJSON method from a type, allowing plain objects to be used in tests. + * This is useful when you want to pass plain object literals where class instances are expected. + * + * @example + * ```typescript + * const messages: NoJSON[] = [ + * { type: 'message', role: 'user', content: [{ type: 'textBlock', text: 'Hello' }] } + * ] + * ``` + */ +export type NoJSON = Omit + +/** + * Plain content block without toJSON method - preserves discriminated union. + */ +export type PlainContentBlock = + | NoJSON + | NoJSON + | NoJSON + | NoJSON + | NoJSON + | NoJSON + | NoJSON + | NoJSON + | NoJSON + | NoJSON + | NoJSON + +/** + * Plain system content block without toJSON method. + */ +export type PlainSystemContentBlock = NoJSON | NoJSON | NoJSON + +/** + * Plain tool result block without toJSON method. + */ +export type PlainToolResultBlock = NoJSON + +/** + * Recursively strips toJSON from a type and its nested content. + * Use this for Message which contains ContentBlock arrays. + */ +export type PlainMessage = NoJSON & { content: PlainContentBlock[] } + +/** + * Type assertion helper for using plain message objects where Message[] is expected. + * Use this when calling model.stream() with plain objects in tests. + * + * @example + * ```typescript + * const messages = [ + * { type: 'message', role: 'user', content: [{ type: 'textBlock', text: 'Hello' }] } + * ] as PlainMessage[] as Message[] + * ``` + */ +export type { Message } diff --git a/strands-ts/src/__fixtures__/test-sandbox.node.ts b/strands-ts/src/__fixtures__/test-sandbox.node.ts new file mode 100644 index 0000000000..efd2329b14 --- /dev/null +++ b/strands-ts/src/__fixtures__/test-sandbox.node.ts @@ -0,0 +1,29 @@ +import { PosixShellSandbox } from '../sandbox/posix-shell.js' +import { shellQuote } from '../utils/shell-quote.js' +import { streamProcess } from '../sandbox/stream-process.js' +import type { ExecuteOptions } from '../sandbox/base.js' +import type { ExecutionResult, StreamChunk } from '../sandbox/types.js' + +/** + * Test sandbox that executes commands within a specific working directory. + * + * Extends PosixShellSandbox so it exercises the same code paths real sandboxes + * use: base64 file encoding, shell quoting, ls parsing, etc. + */ +export class TestSandbox extends PosixShellSandbox { + readonly workingDir: string + + constructor(workingDir: string) { + super() + this.workingDir = workingDir + } + + async *executeStreaming( + command: string, + options?: ExecuteOptions + ): AsyncGenerator { + const cwd = options?.cwd ?? this.workingDir + const fullCommand = `cd ${shellQuote(cwd)} && ${command}` + yield* streamProcess('sh', ['-c', fullCommand], { timeout: options?.timeout, signal: options?.signal }) + } +} diff --git a/strands-ts/src/__fixtures__/tool-helpers.ts b/strands-ts/src/__fixtures__/tool-helpers.ts new file mode 100644 index 0000000000..db54ceeee9 --- /dev/null +++ b/strands-ts/src/__fixtures__/tool-helpers.ts @@ -0,0 +1,125 @@ +/** + * Test fixtures and helpers for Tool testing. + * This module provides utilities for testing Tool implementations. + */ + +import type { Tool, ToolContext } from '../tools/tool.js' +import { TextBlock, ToolResultBlock } from '../types/messages.js' +import type { JSONValue } from '../types/json.js' +import { StateStore } from '../state-store.js' +import { ToolRegistry } from '../registry/tool-registry.js' +import type { PlainToolResultBlock } from './slim-types.js' +import type { InvocationState, LocalAgent } from '../types/agent.js' + +/** + * Helper to create a mock ToolContext for testing. + * + * @param toolUse - The tool use request + * @param appState - Optional initial app state + * @param invocationState - Optional initial invocation state + * @returns Mock ToolContext object + */ +export function createMockContext( + toolUse: { name: string; toolUseId: string; input: JSONValue }, + appState?: Record, + invocationState?: InvocationState +): ToolContext { + return { + toolUse, + agent: { + id: 'mock-agent', + appState: new StateStore(appState), + messages: [], + toolRegistry: new ToolRegistry(), + addHook: () => () => {}, + } as unknown as LocalAgent, + invocationState: invocationState ?? {}, + interrupt: (): never => { + throw new Error('interrupt not available in mock context') + }, + } +} + +/** + * Result function type for createMockTool - accepts plain objects or class instances. + * Can optionally receive the ToolContext for interrupt-aware tools. + */ +type ToolResultFn = + | (() => PlainToolResultBlock | AsyncGenerator) + | (( + context: ToolContext + ) => PlainToolResultBlock | AsyncGenerator | string | void) + +/** + * Helper to create a mock tool for testing. + * + * @param name - The name of the mock tool + * @param resultFn - Function that returns a ToolResultBlock (plain object or class instance) or an AsyncGenerator + * @returns Mock Tool object + */ +export function createMockTool(name: string, resultFn: ToolResultFn): Tool { + return { + name, + description: `Mock tool ${name}`, + toolSpec: { + name, + description: `Mock tool ${name}`, + inputSchema: { type: 'object', properties: {} }, + }, + // eslint-disable-next-line require-yield + async *stream(context): AsyncGenerator { + const result = resultFn(context) + if (typeof result === 'string') { + return new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success', + content: [new TextBlock(result)], + }) + } + if (result === undefined || result === null) { + return new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success', + content: [], + }) + } + if (typeof result === 'object' && result !== null && Symbol.asyncIterator in result) { + // For generators that throw errors + const gen = result as AsyncGenerator + let done = false + while (!done) { + const { value, done: isDone } = await gen.next() + done = isDone ?? false + if (done) { + return value + } + } + // This should never be reached but TypeScript needs a return + throw new Error('Generator ended unexpectedly') + } else { + return result as ToolResultBlock + } + }, + } +} + +/** + * Helper to create a simple mock tool with minimal configuration for testing. + * This is a lighter-weight version of createMockTool for scenarios where the tool's + * execution behavior is not relevant to the test. + * + * @param name - Optional name of the mock tool (defaults to a random UUID) + * @returns Mock Tool object + */ +export function createRandomTool(name?: string): Tool { + const toolName = name ?? globalThis.crypto.randomUUID() + return createMockTool( + toolName, + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success' as const, + content: [], + }) + ) +} diff --git a/strands-ts/src/__tests__/errors.test.ts b/strands-ts/src/__tests__/errors.test.ts new file mode 100644 index 0000000000..e23d9f9654 --- /dev/null +++ b/strands-ts/src/__tests__/errors.test.ts @@ -0,0 +1,225 @@ +import { describe, it, expect } from 'vitest' +import { + ModelError, + ContextWindowOverflowError, + MaxTokensError, + ModelThrottledError, + normalizeError, +} from '../errors.js' +import { Message, TextBlock } from '../types/messages.js' + +describe('ModelError', () => { + describe('when instantiated with a message', () => { + it('creates an error with the correct message', () => { + const message = 'Model error occurred' + const error = new ModelError(message) + + expect(error.message).toBe(message) + }) + + it('has the correct error name', () => { + const error = new ModelError('test') + + expect(error.name).toBe('ModelError') + }) + + it('is an instance of Error', () => { + const error = new ModelError('test') + + expect(error).toBeInstanceOf(Error) + }) + }) + + describe('when instantiated with a cause', () => { + it('stores the cause error', () => { + const cause = new Error('original error') + const error = new ModelError('wrapped error', { cause }) + + expect(error.message).toBe('wrapped error') + expect(error.cause).toBe(cause) + }) + }) +}) + +describe('ContextWindowOverflowError', () => { + describe('when instantiated with a message', () => { + it('creates an error with the correct message', () => { + const message = 'Context window overflow occurred' + const error = new ContextWindowOverflowError(message) + + expect(error.message).toBe(message) + }) + + it('has the correct error name', () => { + const error = new ContextWindowOverflowError('test') + + expect(error.name).toBe('ContextWindowOverflowError') + }) + + it('is an instance of Error', () => { + const error = new ContextWindowOverflowError('test') + + expect(error).toBeInstanceOf(Error) + }) + + it('is an instance of ModelError', () => { + const error = new ContextWindowOverflowError('test') + + expect(error).toBeInstanceOf(ModelError) + }) + }) +}) + +describe('MaxTokensError', () => { + describe('when instantiated with a message and partial message', () => { + it('creates an error with the correct message', () => { + const partialMessage = new Message({ + role: 'assistant', + content: [new TextBlock('partial response')], + }) + const error = new MaxTokensError('Max tokens reached', partialMessage) + + expect(error.message).toBe('Max tokens reached') + }) + + it('has the correct error name', () => { + const partialMessage = new Message({ + role: 'assistant', + content: [new TextBlock('partial response')], + }) + const error = new MaxTokensError('test', partialMessage) + + expect(error.name).toBe('MaxTokensError') + }) + + it('stores the partial message', () => { + const partialMessage = new Message({ + role: 'assistant', + content: [new TextBlock('partial response')], + }) + const error = new MaxTokensError('Max tokens reached', partialMessage) + + expect(error.partialMessage).toBe(partialMessage) + }) + + it('is an instance of Error', () => { + const partialMessage = new Message({ + role: 'assistant', + content: [new TextBlock('partial response')], + }) + const error = new MaxTokensError('test', partialMessage) + + expect(error).toBeInstanceOf(Error) + }) + + it('is an instance of ModelError', () => { + const partialMessage = new Message({ + role: 'assistant', + content: [new TextBlock('partial response')], + }) + const error = new MaxTokensError('test', partialMessage) + + expect(error).toBeInstanceOf(ModelError) + }) + }) +}) + +describe('ModelThrottledError', () => { + describe('when instantiated with a message', () => { + it('creates an error with the correct message', () => { + const message = 'Rate limit exceeded' + const error = new ModelThrottledError(message) + + expect(error.message).toBe(message) + }) + + it('has the correct error name', () => { + const error = new ModelThrottledError('test') + + expect(error.name).toBe('ModelThrottledError') + }) + + it('is an instance of Error', () => { + const error = new ModelThrottledError('test') + + expect(error).toBeInstanceOf(Error) + }) + + it('is an instance of ModelError', () => { + const error = new ModelThrottledError('test') + + expect(error).toBeInstanceOf(ModelError) + }) + }) + + describe('when instantiated with a cause', () => { + it('preserves the original error as cause', () => { + const originalError = new Error('Original rate limit error') + const error = new ModelThrottledError('Rate limit exceeded', { cause: originalError }) + + expect(error.cause).toBe(originalError) + }) + + it('has undefined cause when not provided', () => { + const error = new ModelThrottledError('Rate limit exceeded') + + expect(error.cause).toBeUndefined() + }) + }) +}) + +describe('normalizeError', () => { + describe('when given an Error instance', () => { + it('returns the same Error instance', () => { + const error = new Error('test error') + const result = normalizeError(error) + + expect(result).toBe(error) + }) + }) + + describe('when given a string', () => { + it('wraps it in an Error', () => { + const result = normalizeError('test error') + + expect(result).toBeInstanceOf(Error) + expect(result.message).toBe('test error') + }) + }) + + describe('when given a number', () => { + it('converts it to string and wraps in Error', () => { + const result = normalizeError(42) + + expect(result).toBeInstanceOf(Error) + expect(result.message).toBe('42') + }) + }) + + describe('when given an object', () => { + it('converts it to string and wraps in Error', () => { + const result = normalizeError({ code: 'ERR_TEST' }) + + expect(result).toBeInstanceOf(Error) + expect(result.message).toBe('[object Object]') + }) + }) + + describe('when given null', () => { + it('converts it to string and wraps in Error', () => { + const result = normalizeError(null) + + expect(result).toBeInstanceOf(Error) + expect(result.message).toBe('null') + }) + }) + + describe('when given undefined', () => { + it('converts it to string and wraps in Error', () => { + const result = normalizeError(undefined) + + expect(result).toBeInstanceOf(Error) + expect(result.message).toBe('undefined') + }) + }) +}) diff --git a/strands-ts/src/__tests__/index.test.ts b/strands-ts/src/__tests__/index.test.ts new file mode 100644 index 0000000000..5dfcd0f96f --- /dev/null +++ b/strands-ts/src/__tests__/index.test.ts @@ -0,0 +1,59 @@ +import { describe, it, expect } from 'vitest' +import * as SDK from '../index.js' + +describe('index', () => { + describe('when importing from main entry point', () => { + it('exports error classes', () => { + expect(SDK.ContextWindowOverflowError).toBeDefined() + }) + + it('exports BedrockModel', () => { + expect(SDK.BedrockModel).toBeDefined() + }) + + it('can instantiate BedrockModel', () => { + const provider = new SDK.BedrockModel({ region: 'us-west-2' }) + expect(provider).toBeInstanceOf(SDK.BedrockModel) + expect(provider.getConfig()).toBeDefined() + }) + + it('exports all required types', () => { + // This test ensures all type exports compile correctly + // If any exports are missing, TypeScript will error + const _typeCheck: { + // Error types + contextError: typeof SDK.ContextWindowOverflowError + // Model provider + provider: typeof SDK.BedrockModel + } = { + contextError: SDK.ContextWindowOverflowError, + provider: SDK.BedrockModel, + } + expect(_typeCheck).toBeDefined() + }) + + it('exports streaming event classes as values, not just types', () => { + // Regression: these must be value exports (not `export type`) so they + // survive TypeScript type-erasure and can be used with `instanceof` / + // `new` at runtime. + expect(SDK.ToolStreamEvent).toBeDefined() + expect(SDK.ModelMessageStartEvent).toBeDefined() + expect(SDK.ModelContentBlockStartEvent).toBeDefined() + expect(SDK.ModelContentBlockDeltaEvent).toBeDefined() + expect(SDK.ModelContentBlockStopEvent).toBeDefined() + expect(SDK.ModelMessageStopEvent).toBeDefined() + expect(SDK.ModelMetadataEvent).toBeDefined() + expect(SDK.ModelRedactionEvent).toBeDefined() + }) + + it('can instantiate exported streaming event classes', () => { + const toolEvent = new SDK.ToolStreamEvent({ data: 'test' }) + expect(toolEvent).toBeInstanceOf(SDK.ToolStreamEvent) + expect(toolEvent.type).toBe('toolStreamEvent') + + const msgStart = new SDK.ModelMessageStartEvent({ type: 'modelMessageStartEvent', role: 'assistant' }) + expect(msgStart).toBeInstanceOf(SDK.ModelMessageStartEvent) + expect(msgStart.type).toBe('modelMessageStartEvent') + }) + }) +}) diff --git a/strands-ts/src/__tests__/interrupt.test.ts b/strands-ts/src/__tests__/interrupt.test.ts new file mode 100644 index 0000000000..a67a32fea8 --- /dev/null +++ b/strands-ts/src/__tests__/interrupt.test.ts @@ -0,0 +1,334 @@ +import { describe, expect, it } from 'vitest' +import { Interrupt, InterruptError, InterruptState, interruptFromAgent } from '../interrupt.js' +import { InterruptResponseContent } from '../types/interrupt.js' + +describe('Interrupt', () => { + it('constructs with all fields and supports response mutation', () => { + const interrupt = new Interrupt({ + id: 'int-1', + name: 'confirm_action', + reason: 'Please confirm', + response: 'approved', + }) + + expect(interrupt).toEqual({ + id: 'int-1', + name: 'confirm_action', + reason: 'Please confirm', + response: 'approved', + source: 'hook', + }) + + // response is mutable after construction + interrupt.response = 'changed' + expect(interrupt.response).toBe('changed') + }) + + it('round-trips through JSON serialization with complex data', () => { + const original = new Interrupt({ + id: 'int-1', + name: 'test', + reason: { complex: { nested: 'data' } }, + response: ['array', 'response'], + }) + + const serialized = JSON.stringify(original) + const deserialized = Interrupt.fromJSON(JSON.parse(serialized)) + + expect(deserialized).toEqual(original) + }) + + it('omits undefined reason/response from toJSON', () => { + const interrupt = new Interrupt({ id: 'int-1', name: 'test' }) + + const json = interrupt.toJSON() + expect(json).toStrictEqual({ id: 'int-1', name: 'test', source: 'hook' }) + expect('reason' in json).toBe(false) + expect('response' in json).toBe(false) + }) +}) + +describe('InterruptError', () => { + it('creates catchable error with single interrupt', () => { + const interrupt = new Interrupt({ id: 'int-1', name: 'confirm_delete' }) + const error = new InterruptError(interrupt) + + expect(error).toBeInstanceOf(Error) + expect(error).toMatchObject({ + name: 'InterruptError', + message: 'Interrupt raised: confirm_delete', + interrupts: [interrupt], + }) + }) + + it('creates error with multiple interrupts', () => { + const a = new Interrupt({ id: 'int-1', name: 'security_check' }) + const b = new Interrupt({ id: 'int-2', name: 'budget_check' }) + const error = new InterruptError([a, b]) + + expect(error).toBeInstanceOf(Error) + expect(error).toMatchObject({ + name: 'InterruptError', + message: '2 interrupts raised: security_check, budget_check', + interrupts: [a, b], + }) + }) +}) + +describe('InterruptState', () => { + describe('getOrCreateInterrupt', () => { + it('creates new interrupt and stores it', () => { + const state = new InterruptState() + + const interrupt = state.getOrCreateInterrupt('int-1', 'test', 'reason') + + expect(interrupt).toEqual({ id: 'int-1', name: 'test', reason: 'reason', source: 'hook' }) + expect(state.interrupts['int-1']).toBe(interrupt) + expect(state.getInterruptsList()).toStrictEqual([interrupt]) + }) + + it('returns existing interrupt by ID without overwriting', () => { + const state = new InterruptState() + const first = state.getOrCreateInterrupt('int-1', 'test', 'reason') + first.response = 'user response' + + const second = state.getOrCreateInterrupt('int-1', 'different', 'different reason') + + expect(second).toBe(first) + expect(second.response).toBe('user response') + }) + + it('creates separate interrupts for different IDs with same name', () => { + const state = new InterruptState() + state.activate() + const first = state.getOrCreateInterrupt('tool:tool-1:0:confirm', 'confirm', 'reason') + first.response = { approved: true } + + const second = state.getOrCreateInterrupt('tool:tool-2:0:confirm', 'confirm', 'reason') + + expect(second).not.toBe(first) + expect(second.id).toBe('tool:tool-2:0:confirm') + expect(second.response).toBeUndefined() + }) + + it('creates interrupt with preemptive response', () => { + const state = new InterruptState() + const interrupt = state.getOrCreateInterrupt('int-1', 'confirm', 'reason', 'pre-approved') + + expect(interrupt).toEqual({ + id: 'int-1', + name: 'confirm', + reason: 'reason', + response: 'pre-approved', + source: 'hook', + }) + }) + + it('ignores preemptive response when interrupt already exists', () => { + const state = new InterruptState() + const first = state.getOrCreateInterrupt('int-1', 'confirm', 'reason') + first.response = 'user response' + + const second = state.getOrCreateInterrupt('int-1', 'confirm', 'reason', 'preemptive') + + expect(second).toBe(first) + expect(second).toEqual({ + id: 'int-1', + name: 'confirm', + reason: 'reason', + response: 'user response', + source: 'hook', + }) + }) + }) + + describe('activate / deactivate', () => { + it('deactivate clears all state', () => { + const state = new InterruptState() + state.getOrCreateInterrupt('int-1', 'test') + state.activate() + expect(state.activated).toBe(true) + + state.deactivate() + + expect(state).toMatchObject({ + interrupts: {}, + resumeResponses: undefined, + activated: false, + }) + }) + }) + + describe('resume', () => { + it('does nothing when not activated', () => { + const state = new InterruptState() + state.getOrCreateInterrupt('int-1', 'test') + + state.resume([new InterruptResponseContent({ interruptId: 'int-1', response: 'yes' })]) + + expect(state.interrupts['int-1']!.response).toBeUndefined() + }) + + it('populates interrupt responses and stores resumeResponses when activated', () => { + const state = new InterruptState() + state.getOrCreateInterrupt('int-1', 'first') + state.getOrCreateInterrupt('int-2', 'second') + state.activate() + + const responses = [ + new InterruptResponseContent({ interruptId: 'int-1', response: 'response1' }), + new InterruptResponseContent({ interruptId: 'int-2', response: { complex: 'data' } }), + ] + state.resume(responses) + + expect(state.interrupts['int-1']).toMatchObject({ response: 'response1' }) + expect(state.interrupts['int-2']).toMatchObject({ response: { complex: 'data' } }) + expect(state.resumeResponses).toBe(responses) + }) + + it('throws error for unknown interrupt ID', () => { + const state = new InterruptState() + state.getOrCreateInterrupt('int-1', 'test') + state.activate() + + expect(() => { + state.resume([new InterruptResponseContent({ interruptId: 'unknown', response: 'yes' })]) + }).toThrow('interrupt_id= | no interrupt found') + }) + }) + + describe('serialization', () => { + it('round-trips through JSON with full state', () => { + const original = new InterruptState() + original.getOrCreateInterrupt('int-1', 'test', { complex: 'reason' }) + original.interrupts['int-1']!.response = ['array', 'response'] + original.activate() + + const serialized = JSON.stringify(original) + const deserialized = InterruptState.fromJSON(JSON.parse(serialized)) + + expect(deserialized.toJSON()).toStrictEqual(original.toJSON()) + }) + + it('round-trips pendingToolExecution through JSON', () => { + const original = new InterruptState() + original.getOrCreateInterrupt('int-1', 'test') + original.activate() + original.setPendingToolExecution({ + assistantMessageData: { + role: 'assistant' as const, + content: [{ toolUse: { name: 'tool', toolUseId: 't-1', input: {} } }], + }, + completedToolResults: { + 't-0': { toolResult: { toolUseId: 't-0', status: 'success' as const, content: [] } }, + }, + }) + + const serialized = JSON.stringify(original) + const deserialized = InterruptState.fromJSON(JSON.parse(serialized)) + + expect(deserialized.toJSON()).toStrictEqual(original.toJSON()) + expect(deserialized.pendingToolExecution).toStrictEqual(original.pendingToolExecution) + }) + + it('deserializes state with resumeResponses', () => { + const state = InterruptState.fromJSON({ + interrupts: { + 'int-1': { id: 'int-1', name: 'test', reason: 'reason', response: 'yes' }, + }, + resumeResponses: [{ interruptResponse: { interruptId: 'int-1', response: 'yes' } }], + activated: true, + }) + + expect(state).toMatchObject({ + activated: true, + interrupts: { + 'int-1': { id: 'int-1', name: 'test', reason: 'reason', response: 'yes' }, + }, + resumeResponses: [{ interruptResponse: { interruptId: 'int-1', response: 'yes' } }], + }) + }) + }) +}) + +describe('interruptFromAgent', () => { + // Minimal agent-like object with _interruptState + function createMockAgent(state: InterruptState) { + return { _interruptState: state } as unknown as import('../types/agent.js').LocalAgent + } + + it('returns preemptive response immediately without throwing', () => { + const state = new InterruptState() + const agent = createMockAgent(state) + + const result = interruptFromAgent( + agent, + 'test-id', + { + name: 'confirm', + reason: 'need approval', + response: 'pre-approved', + }, + 'tool' + ) + + expect(result).toBe('pre-approved') + expect(state.interrupts['test-id']).toMatchObject({ + id: 'test-id', + name: 'confirm', + reason: 'need approval', + response: 'pre-approved', + source: 'tool', + }) + }) + + it('returns resume response over preemptive response for existing interrupt', () => { + const state = new InterruptState() + state.getOrCreateInterrupt('test-id', 'confirm', 'need approval') + state.interrupts['test-id']!.response = 'user-provided' + + const agent = createMockAgent(state) + + const result = interruptFromAgent( + agent, + 'test-id', + { + name: 'confirm', + reason: 'need approval', + response: 'preemptive', + }, + 'tool' + ) + + expect(result).toBe('user-provided') + expect(state.interrupts['test-id']).toMatchObject({ + id: 'test-id', + name: 'confirm', + reason: 'need approval', + response: 'user-provided', + }) + }) + + it('does not interrupt when preemptive response is null', () => { + const state = new InterruptState() + const agent = createMockAgent(state) + + const result = interruptFromAgent( + agent, + 'test-id', + { + name: 'confirm', + response: null, + }, + 'tool' + ) + + expect(result).toBeNull() + expect(state.interrupts['test-id']).toMatchObject({ + id: 'test-id', + name: 'confirm', + response: null, + source: 'tool', + }) + }) +}) diff --git a/strands-ts/src/__tests__/mcp-config.test.node.ts b/strands-ts/src/__tests__/mcp-config.test.node.ts new file mode 100644 index 0000000000..095311dc59 --- /dev/null +++ b/strands-ts/src/__tests__/mcp-config.test.node.ts @@ -0,0 +1,404 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' +import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' +import { McpClient } from '../mcp.js' + +vi.mock('node:fs/promises', () => ({ + readFile: vi.fn(), +})) + +vi.mock('node:os', () => ({ + homedir: vi.fn(() => '/home/user'), +})) + +vi.mock('node:path', () => ({ + join: (...segments: string[]) => segments.join('/'), +})) + +vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => ({ + StdioClientTransport: vi.fn(function () {}), + getDefaultEnvironment: vi.fn(() => ({ PATH: '/usr/bin', HOME: '/home/user' })), +})) + +vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({ + StreamableHTTPClientTransport: vi.fn(function () {}), +})) + +vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => ({ + SSEClientTransport: vi.fn(function () {}), +})) + +vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({ + Client: vi.fn(function (this: Record) { + this.connect = vi.fn() + this.close = vi.fn() + this.listTools = vi.fn() + this.callTool = vi.fn() + this.setRequestHandler = vi.fn() + this.setNotificationHandler = vi.fn() + this.getServerCapabilities = vi.fn() + this.getServerVersion = vi.fn() + this.getInstructions = vi.fn() + this.experimental = { tasks: { callToolStream: vi.fn() } } + }), +})) + +describe('McpClient.loadServers', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('transport detection', () => { + it('creates StdioClientTransport when command is present', async () => { + const clients = await McpClient.loadServers({ + 'my-server': { command: 'node', args: ['server.js'] }, + }) + + expect(clients).toHaveLength(1) + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: 'node', + args: ['server.js'], + }) + }) + + it('creates McpClient with url when url is present', async () => { + const clients = await McpClient.loadServers({ + 'remote-server': { url: 'https://example.com/mcp' }, + }) + + expect(clients).toHaveLength(1) + expect(StdioClientTransport).not.toHaveBeenCalled() + expect(SSEClientTransport).not.toHaveBeenCalled() + }) + + it('creates SSEClientTransport when transport is "sse"', async () => { + const clients = await McpClient.loadServers({ + 'sse-server': { url: 'https://example.com/sse', transport: 'sse' }, + }) + + expect(clients).toHaveLength(1) + expect(SSEClientTransport).toHaveBeenCalledWith(new URL('https://example.com/sse'), undefined) + }) + + it('explicit transport overrides auto-detection', async () => { + const clients = await McpClient.loadServers({ + server: { url: 'https://example.com/mcp', transport: 'sse' }, + }) + + expect(clients).toHaveLength(1) + expect(SSEClientTransport).toHaveBeenCalled() + expect(StreamableHTTPClientTransport).not.toHaveBeenCalled() + }) + + it('interpolates auth credentials for streamable-http', async () => { + vi.stubEnv('CLIENT_ID', 'my-id') + vi.stubEnv('CLIENT_SECRET', 'my-secret') + + const clients = await McpClient.loadServers({ + server: { + url: 'https://example.com/mcp', + auth: { clientId: '${CLIENT_ID}', clientSecret: '${CLIENT_SECRET}' }, + }, + }) + + expect(clients).toHaveLength(1) + }) + + it('throws when auth credential references missing env var', async () => { + vi.unstubAllEnvs() + + await expect( + McpClient.loadServers({ + server: { + url: 'https://example.com/mcp', + auth: { clientId: '${MISSING_ID}', clientSecret: 'literal' }, + }, + }) + ).rejects.toThrow('Environment variable "MISSING_ID" is not set') + }) + }) + + describe('env interpolation', () => { + it('interpolates ${VAR} in env values', async () => { + vi.stubEnv('MY_SECRET', 'secret123') + + await McpClient.loadServers({ + server: { command: 'node', env: { SECRET: '${MY_SECRET}' } }, + }) + + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: 'node', + env: { PATH: '/usr/bin', HOME: '/home/user', SECRET: 'secret123' }, + }) + }) + + it('interpolates ${VAR} in headers', async () => { + vi.stubEnv('TOKEN', 'abc') + + await McpClient.loadServers({ + server: { url: 'https://example.com/sse', transport: 'sse', headers: { Authorization: 'Bearer ${TOKEN}' } }, + }) + + expect(SSEClientTransport).toHaveBeenCalledWith(new URL('https://example.com/sse'), { + requestInit: { headers: { Authorization: 'Bearer abc' } }, + }) + }) + + it('interpolates ${VAR} in url', async () => { + vi.stubEnv('HOST', 'myhost.com') + + const clients = await McpClient.loadServers({ + server: { url: 'https://${HOST}/mcp', transport: 'sse' }, + }) + + expect(clients).toHaveLength(1) + expect(SSEClientTransport).toHaveBeenCalledWith(new URL('https://myhost.com/mcp'), undefined) + }) + + it('throws when env var is not set', async () => { + vi.unstubAllEnvs() + + await expect( + McpClient.loadServers({ + server: { command: 'node', env: { VAL: '${NONEXISTENT_VAR}' } }, + }) + ).rejects.toThrow('Environment variable "NONEXISTENT_VAR" is not set') + }) + + it('skips server with missing env var when continueOnError is true', async () => { + vi.unstubAllEnvs() + + const clients = await McpClient.loadServers({ + broken: { command: 'node', env: { VAL: '${NONEXISTENT_VAR}' }, continueOnError: true }, + working: { command: 'node' }, + }) + + expect(clients).toHaveLength(1) + expect(clients[0]!.clientName).toBe('working') + }) + + it('interpolates ${VAR} in cwd', async () => { + vi.stubEnv('PROJECT_DIR', '/home/user/projects') + + await McpClient.loadServers({ + server: { command: 'node', cwd: '${PROJECT_DIR}/my-server' }, + }) + + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: 'node', + cwd: '/home/user/projects/my-server', + }) + }) + + it('merges env with default environment', async () => { + await McpClient.loadServers({ + server: { command: 'node', env: { CUSTOM: 'value' } }, + }) + + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: 'node', + env: { PATH: '/usr/bin', HOME: '/home/user', CUSTOM: 'value' }, + }) + }) + }) + + describe('file config loading', () => { + it('reads and parses a JSON file', async () => { + const { readFile } = await import('node:fs/promises') + vi.mocked(readFile).mockResolvedValue( + JSON.stringify({ + 'my-server': { command: 'node', args: ['server.js'] }, + }) + ) + + const clients = await McpClient.loadServers('/path/to/config.json') + + expect(readFile).toHaveBeenCalledWith('/path/to/config.json', 'utf-8') + expect(clients).toHaveLength(1) + }) + + it('extracts mcpServers key when present', async () => { + const { readFile } = await import('node:fs/promises') + vi.mocked(readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + 'server-a': { command: 'node' }, + 'server-b': { url: 'https://example.com' }, + }, + }) + ) + + const clients = await McpClient.loadServers('/path/to/config.json') + + expect(clients).toHaveLength(2) + }) + + it('uses whole object when mcpServers key is absent', async () => { + const { readFile } = await import('node:fs/promises') + vi.mocked(readFile).mockResolvedValue( + JSON.stringify({ + 'server-a': { command: 'node' }, + }) + ) + + const clients = await McpClient.loadServers('/path/to/config.json') + + expect(clients).toHaveLength(1) + }) + + it('expands ~ to home directory', async () => { + const { readFile } = await import('node:fs/promises') + vi.mocked(readFile).mockResolvedValue( + JSON.stringify({ + server: { command: 'node' }, + }) + ) + + await McpClient.loadServers('~/config/mcp.json') + + expect(readFile).toHaveBeenCalledWith('/home/user/config/mcp.json', 'utf-8') + }) + }) + + describe('defaults and per-server overrides', () => { + it('per-server continueOnError overrides defaults', async () => { + const clients = await McpClient.loadServers( + { + 'strict-server': { command: 'node', continueOnError: false }, + 'lenient-server': { command: 'node', continueOnError: true }, + }, + { continueOnError: true } + ) + + expect(clients[0]!.continueOnError).toBe(false) + expect(clients[1]!.continueOnError).toBe(true) + }) + + it('applies default continueOnError when server does not override', async () => { + const clients = await McpClient.loadServers({ server: { command: 'node' } }, { continueOnError: true }) + + expect(clients[0]!.continueOnError).toBe(true) + }) + + it('uses server name as applicationName when not in defaults', async () => { + const clients = await McpClient.loadServers({ + 'my-named-server': { command: 'node' }, + }) + + expect(clients[0]!.clientName).toBe('my-named-server') + }) + + it('uses defaults applicationName over server name', async () => { + const clients = await McpClient.loadServers({ server: { command: 'node' } }, { applicationName: 'my-app' }) + + expect(clients[0]!.clientName).toBe('my-app') + }) + }) + + describe('error cases', () => { + it('throws when server has neither command nor url', async () => { + await expect(McpClient.loadServers({ bad: {} })).rejects.toThrow( + 'Server config must include either "command" (stdio) or "url" (http)' + ) + }) + + it('throws when stdio transport specified without command', async () => { + await expect(McpClient.loadServers({ bad: { transport: 'stdio' } })).rejects.toThrow( + 'Stdio transport requires "command" field' + ) + }) + + it('throws when streamable-http transport specified without url', async () => { + await expect(McpClient.loadServers({ bad: { transport: 'streamable-http' } })).rejects.toThrow( + 'Streamable HTTP transport requires "url" field' + ) + }) + + it('throws when sse transport specified without url', async () => { + await expect(McpClient.loadServers({ bad: { transport: 'sse' } })).rejects.toThrow( + 'SSE transport requires "url" field' + ) + }) + + it('throws on invalid file path', async () => { + const { readFile } = await import('node:fs/promises') + vi.mocked(readFile).mockRejectedValue(new Error('ENOENT: no such file or directory')) + + await expect(McpClient.loadServers('/nonexistent/path.json')).rejects.toThrow('ENOENT') + }) + + it('throws on malformed JSON', async () => { + const { readFile } = await import('node:fs/promises') + vi.mocked(readFile).mockResolvedValue('not json{{{') + + await expect(McpClient.loadServers('/path/to/bad.json')).rejects.toThrow() + }) + + it('throws when auth is used with sse transport', async () => { + await expect( + McpClient.loadServers({ + server: { + url: 'https://example.com', + transport: 'sse', + auth: { clientId: 'id', clientSecret: 'secret' }, + }, + }) + ).rejects.toThrow('SSE transport does not support auth') + }) + + it('throws on invalid config shape', async () => { + const { readFile } = await import('node:fs/promises') + vi.mocked(readFile).mockResolvedValue(JSON.stringify([1, 2, 3])) + + await expect(McpClient.loadServers('/path/to/bad.json')).rejects.toThrow('MCP config must be a JSON object') + }) + + it('throws when server has both command and url without explicit transport', async () => { + await expect(McpClient.loadServers({ bad: { command: 'node', url: 'https://example.com' } })).rejects.toThrow( + 'Server config has both "command" and "url"' + ) + }) + }) + + describe('disabled', () => { + it('skips disabled servers', async () => { + const clients = await McpClient.loadServers({ + active: { command: 'node' }, + inactive: { command: 'node', disabled: true }, + }) + + expect(clients).toHaveLength(1) + expect(clients[0]!.clientName).toBe('active') + }) + }) + + describe('env interpolation syntax', () => { + it('supports ${env:VAR} namespaced syntax', async () => { + vi.stubEnv('MY_TOKEN', 'token123') + + await McpClient.loadServers({ + server: { command: 'node', env: { TOKEN: '${env:MY_TOKEN}' } }, + }) + + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: 'node', + env: { PATH: '/usr/bin', HOME: '/home/user', TOKEN: 'token123' }, + }) + }) + + it('interpolates ${VAR} in command and args', async () => { + vi.stubEnv('MY_CMD', '/usr/local/bin/server') + vi.stubEnv('MY_ARG', '3000') + + await McpClient.loadServers({ + server: { command: '${MY_CMD}', args: ['--port=${MY_ARG}'] }, + }) + + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: '/usr/local/bin/server', + args: ['--port=3000'], + }) + }) + }) +}) diff --git a/strands-ts/src/__tests__/mcp.test.ts b/strands-ts/src/__tests__/mcp.test.ts new file mode 100644 index 0000000000..92516f4dd7 --- /dev/null +++ b/strands-ts/src/__tests__/mcp.test.ts @@ -0,0 +1,1348 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { Client } from '@modelcontextprotocol/sdk/client/index.js' +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' +import { + McpError, + ErrorCode, + ElicitRequestSchema, + UrlElicitationRequiredError, +} from '@modelcontextprotocol/sdk/types.js' +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' +import { ClientCredentialsProvider } from '@modelcontextprotocol/sdk/client/auth-extensions.js' +import { McpClient } from '../mcp.js' +import { McpTool } from '../tools/mcp-tool.js' +import { JsonBlock, type TextBlock, type ToolResultBlock } from '../types/messages.js' +import { ImageBlock } from '../types/media.js' +import type { LocalAgent } from '../types/agent.js' +import type { ToolContext } from '../tools/tool.js' +import type { ElicitationCallback } from '../types/elicitation.js' +import { context, propagation, trace, TraceFlags } from '@opentelemetry/api' +import type { SpanContext } from '@opentelemetry/api' +import { logger } from '../logging/index.js' +import type { LoggingMessageNotificationParams } from '@modelcontextprotocol/sdk/types.js' + +/** + * Helper to create a mock async generator that yields a result message. + * This simulates the behavior of callToolStream returning a stream that ends with a result. + */ +function createMockCallToolStream(result: unknown) { + return async function* () { + yield { type: 'result', result } + } +} + +vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({ + StreamableHTTPClientTransport: vi.fn(function () { + return { start: vi.fn(), send: vi.fn(), close: vi.fn() } + }), +})) + +vi.mock('@modelcontextprotocol/sdk/client/auth-extensions.js', () => ({ + ClientCredentialsProvider: vi.fn(function () { + return { redirectUrl: undefined, clientMetadata: { client_id: 'test' } } + }), +})) + +vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({ + Client: vi.fn(function () { + return { + connect: vi.fn(), + close: vi.fn(), + listTools: vi.fn(), + callTool: vi.fn(), + setRequestHandler: vi.fn(), + setNotificationHandler: vi.fn(), + getServerCapabilities: vi.fn(), + getServerVersion: vi.fn(), + getInstructions: vi.fn(), + experimental: { + tasks: { + callToolStream: vi.fn(), + }, + }, + } + }), +})) + +vi.mock('../tools/tool.js', () => ({ + Tool: class {}, + createErrorResult: (err: unknown, toolUseId: string) => ({ + type: 'toolResultBlock', + status: 'error', + toolUseId, + content: [{ type: 'textBlock', text: err instanceof Error ? err.message : String(err) }], + }), +})) + +/** + * Executes a tool stream to completion and returns the final result. + */ +async function runTool(gen: AsyncGenerator): Promise { + let result = await gen.next() + while (!result.done) { + result = await gen.next() + } + return result.value as T +} + +/** + * Mock an active span with a valid trace ID via trace.getSpan, + * and stub propagation.inject to populate the carrier with a traceparent. + */ +function mockActiveSpan(traceId: string = '1234567890abcdef1234567890abcdef', traceFlags = TraceFlags.SAMPLED): void { + const mockSpan = { + spanContext: () => + ({ + traceId, + spanId: '1234567890abcdef', + traceFlags, + }) as SpanContext, + } + vi.spyOn(trace, 'getSpan').mockReturnValue(mockSpan as unknown as ReturnType) + vi.spyOn(propagation, 'inject').mockImplementation((_context, carrier) => { + if (carrier && typeof carrier === 'object') { + ;(carrier as Record).traceparent = `00-${traceId}-1234567890abcdef-01` + } + }) +} + +const mockTransport = { + connect: vi.fn(), + close: vi.fn(), + send: vi.fn(), +} as unknown as Transport + +describe('MCP Integration', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + function createElicitationClient(callback: ElicitationCallback) { + const resultsLengthBefore = vi.mocked(Client).mock.results.length + const elicitClient = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + elicitationCallback: callback, + }) + const elicitSdkClientMock = vi.mocked(Client).mock.results[resultsLengthBefore]!.value + return { elicitClient, elicitSdkClientMock } + } + + async function connectAndGetElicitationHandler(callback: ElicitationCallback) { + const { elicitClient, elicitSdkClientMock } = createElicitationClient(callback) + await elicitClient.connect() + const handler = elicitSdkClientMock.setRequestHandler.mock.calls[0]![1] + return { handler, elicitSdkClientMock } + } + + describe('McpClient', () => { + let client: McpClient + let sdkClientMock: { + connect: ReturnType + close: ReturnType + listTools: ReturnType + callTool: ReturnType + setRequestHandler: ReturnType + setNotificationHandler: ReturnType + getServerCapabilities: ReturnType + getServerVersion: ReturnType + getInstructions: ReturnType + experimental: { tasks: { callToolStream: ReturnType } } + } + + beforeEach(() => { + client = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + }) + sdkClientMock = vi.mocked(Client).mock.results[0]!.value + }) + + it('initializes SDK client with correct configuration', () => { + expect(Client).toHaveBeenCalledWith( + { name: 'TestApp', version: '0.0.1' }, + expect.objectContaining({ + listChanged: expect.objectContaining({ + tools: expect.objectContaining({ autoRefresh: false, debounceMs: 300 }), + }), + }) + ) + }) + + it('injects trace context into tool arguments when active span exists', async () => { + mockActiveSpan() + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client }) + sdkClientMock.callTool.mockResolvedValue({ content: [] }) + + await client.callTool(tool, { op: 'add' }) + + const callArgs = sdkClientMock.callTool.mock.calls[0]![0] + expect(callArgs.arguments).toStrictEqual({ + op: 'add', + _meta: { traceparent: '00-1234567890abcdef1234567890abcdef-1234567890abcdef-01' }, + }) + }) + + it('merges trace context with existing _meta field', async () => { + mockActiveSpan() + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client }) + sdkClientMock.callTool.mockResolvedValue({ content: [] }) + + await client.callTool(tool, { op: 'add', _meta: { progressToken: 'tok-1' } }) + + const callArgs = sdkClientMock.callTool.mock.calls[0]![0] + expect(callArgs.arguments).toStrictEqual({ + op: 'add', + _meta: { + progressToken: 'tok-1', + traceparent: '00-1234567890abcdef1234567890abcdef-1234567890abcdef-01', + }, + }) + }) + + it('passes args unchanged when no active span exists', async () => { + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client }) + sdkClientMock.callTool.mockResolvedValue({ content: [] }) + + await client.callTool(tool, { op: 'add' }) + + const callArgs = sdkClientMock.callTool.mock.calls[0]![0] + expect(callArgs.arguments).toStrictEqual({ op: 'add' }) + }) + + it('passes args unchanged when span has empty trace ID', async () => { + mockActiveSpan('', TraceFlags.NONE) + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client }) + sdkClientMock.callTool.mockResolvedValue({ content: [] }) + + await client.callTool(tool, { op: 'add' }) + + const callArgs = sdkClientMock.callTool.mock.calls[0]![0] + expect(callArgs.arguments).toStrictEqual({ op: 'add' }) + }) + + it('passes args unchanged when context injection fails', async () => { + vi.spyOn(context, 'active').mockImplementation(() => { + throw new Error('Context error') + }) + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client }) + sdkClientMock.callTool.mockResolvedValue({ content: [] }) + + await client.callTool(tool, { op: 'add' }) + + const callArgs = sdkClientMock.callTool.mock.calls[0]![0] + expect(callArgs.arguments).toStrictEqual({ op: 'add' }) + }) + + it('skips trace context injection when disableMcpInstrumentation is true', async () => { + mockActiveSpan() + const noInstrClient = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + disableMcpInstrumentation: true, + }) + const noInstrSdkMock = vi.mocked(Client).mock.results.at(-1)!.value + noInstrSdkMock.callTool.mockResolvedValue({ content: [] }) + + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client: noInstrClient }) + + await noInstrClient.callTool(tool, { op: 'add' }) + + const callArgs = noInstrSdkMock.callTool.mock.calls[0]![0] + expect(callArgs.arguments).toStrictEqual({ op: 'add' }) + }) + + it('manages connection state lazily', async () => { + await client.connect() + expect(sdkClientMock.connect).toHaveBeenCalledTimes(1) + + await client.connect() + expect(sdkClientMock.connect).toHaveBeenCalledTimes(1) + }) + + it('supports forced reconnection', async () => { + await client.connect() + await client.connect(true) + + expect(sdkClientMock.close).toHaveBeenCalled() + expect(sdkClientMock.connect).toHaveBeenCalledTimes(2) + }) + + it('converts SDK tool specs to McpTool instances', async () => { + sdkClientMock.listTools.mockResolvedValue({ + tools: [{ name: 'weather', description: 'Get weather', inputSchema: {} }], + }) + + const tools = await client.listTools() + + expect(sdkClientMock.connect).toHaveBeenCalled() + expect(tools).toHaveLength(1) + expect(tools[0]).toBeInstanceOf(McpTool) + expect(tools[0]!.name).toBe('weather') + }) + + it('paginates through all pages of tools', async () => { + sdkClientMock.listTools + .mockResolvedValueOnce({ + tools: [{ name: 'tool_a', description: 'A', inputSchema: {} }], + nextCursor: 'page2', + }) + .mockResolvedValueOnce({ + tools: [{ name: 'tool_b', description: 'B', inputSchema: {} }], + nextCursor: 'page3', + }) + .mockResolvedValueOnce({ + tools: [{ name: 'tool_c', description: 'C', inputSchema: {} }], + }) + + const tools = await client.listTools() + + expect(tools).toHaveLength(3) + expect(tools.map((t) => t.name)).toEqual(['tool_a', 'tool_b', 'tool_c']) + expect(sdkClientMock.listTools).toHaveBeenCalledTimes(3) + expect(sdkClientMock.listTools).toHaveBeenNthCalledWith(1, undefined) + expect(sdkClientMock.listTools).toHaveBeenNthCalledWith(2, { cursor: 'page2' }) + expect(sdkClientMock.listTools).toHaveBeenNthCalledWith(3, { cursor: 'page3' }) + }) + + it('generates description fallback when description is missing', async () => { + sdkClientMock.listTools.mockResolvedValue({ + tools: [{ name: 'my_tool', inputSchema: {} }], + }) + + const tools = await client.listTools() + + expect(tools[0]!.description).toBe('Tool which performs my_tool') + }) + + it('generates description fallback when description is empty string', async () => { + sdkClientMock.listTools.mockResolvedValue({ + tools: [{ name: 'my_tool', description: '', inputSchema: {} }], + }) + + const tools = await client.listTools() + + expect(tools[0]!.description).toBe('Tool which performs my_tool') + }) + + it('uses callTool when tasksConfig is undefined (default)', async () => { + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client }) + sdkClientMock.callTool.mockResolvedValue({ content: [] }) + + await client.callTool(tool, { op: 'add' }) + + expect(sdkClientMock.connect).toHaveBeenCalled() + expect(sdkClientMock.callTool).toHaveBeenCalledWith( + { name: 'calc', arguments: { op: 'add' } }, + undefined, + undefined + ) + expect(sdkClientMock.experimental.tasks.callToolStream).not.toHaveBeenCalled() + }) + + it('forwards abort signal to SDK callTool', async () => { + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client }) + sdkClientMock.callTool.mockResolvedValue({ content: [] }) + const controller = new AbortController() + + await client.callTool(tool, { op: 'add' }, { signal: controller.signal }) + + expect(sdkClientMock.callTool).toHaveBeenCalledWith({ name: 'calc', arguments: { op: 'add' } }, undefined, { + signal: controller.signal, + }) + }) + + it('forwards abort signal to callToolStream when tasksConfig is provided', async () => { + const resultsLengthBefore = vi.mocked(Client).mock.results.length + const taskClient = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + tasksConfig: {}, + }) + const taskSdkClientMock = vi.mocked(Client).mock.results[resultsLengthBefore]!.value + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client: taskClient }) + taskSdkClientMock.experimental.tasks.callToolStream.mockReturnValue(createMockCallToolStream({ content: [] })()) + const controller = new AbortController() + + await taskClient.callTool(tool, { op: 'add' }, { signal: controller.signal }) + + expect(taskSdkClientMock.experimental.tasks.callToolStream).toHaveBeenCalledWith( + { name: 'calc', arguments: { op: 'add' } }, + undefined, + { timeout: 60000, maxTotalTimeout: 300000, resetTimeoutOnProgress: true, signal: controller.signal } + ) + }) + + it('uses callToolStream when tasksConfig is provided (empty object)', async () => { + const resultsLengthBefore = vi.mocked(Client).mock.results.length + const taskClient = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + tasksConfig: {}, + }) + const taskSdkClientMock = vi.mocked(Client).mock.results[resultsLengthBefore]!.value + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client: taskClient }) + taskSdkClientMock.experimental.tasks.callToolStream.mockReturnValue(createMockCallToolStream({ content: [] })()) + + await taskClient.callTool(tool, { op: 'add' }) + + expect(taskSdkClientMock.connect).toHaveBeenCalled() + expect(taskSdkClientMock.experimental.tasks.callToolStream).toHaveBeenCalledWith( + { name: 'calc', arguments: { op: 'add' } }, + undefined, + { timeout: 60000, maxTotalTimeout: 300000, resetTimeoutOnProgress: true } + ) + expect(taskSdkClientMock.callTool).not.toHaveBeenCalled() + }) + + it('passes custom TTL and pollTimeout to callToolStream', async () => { + const resultsLengthBefore = vi.mocked(Client).mock.results.length + const taskClient = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + tasksConfig: { ttl: 30000, pollTimeout: 120000 }, + }) + const taskSdkClientMock = vi.mocked(Client).mock.results[resultsLengthBefore]!.value + const tool = new McpTool({ name: 'calc', description: '', inputSchema: {}, client: taskClient }) + taskSdkClientMock.experimental.tasks.callToolStream.mockReturnValue(createMockCallToolStream({ content: [] })()) + + await taskClient.callTool(tool, { op: 'add' }) + + expect(taskSdkClientMock.experimental.tasks.callToolStream).toHaveBeenCalledWith( + { name: 'calc', arguments: { op: 'add' } }, + undefined, + { timeout: 30000, maxTotalTimeout: 120000, resetTimeoutOnProgress: true } + ) + }) + + it('validates tool arguments', async () => { + const tool = new McpTool({ name: 't', description: '', inputSchema: {}, client }) + await expect(client.callTool(tool, ['invalid-array'])).rejects.toThrow(/JSON Object/) + }) + + it('cleans up resources', async () => { + await client.disconnect() + expect(sdkClientMock.close).toHaveBeenCalled() + expect(mockTransport.close).toHaveBeenCalled() + }) + + it('supports Symbol.asyncDispose for await using pattern', async () => { + await client[Symbol.asyncDispose]() + expect(sdkClientMock.close).toHaveBeenCalled() + expect(mockTransport.close).toHaveBeenCalled() + }) + + it('registers elicitation handler before connecting when callback is provided', async () => { + const resultsLengthBefore = vi.mocked(Client).mock.results.length + const callback: ElicitationCallback = vi.fn() + const elicitClient = new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + elicitationCallback: callback, + }) + const elicitSdkClientMock = vi.mocked(Client).mock.results[resultsLengthBefore]!.value + + await elicitClient.connect() + + expect(elicitSdkClientMock.setRequestHandler).toHaveBeenCalledWith(ElicitRequestSchema, expect.any(Function)) + const setHandlerOrder = elicitSdkClientMock.setRequestHandler.mock.invocationCallOrder[0]! + const connectOrder = elicitSdkClientMock.connect.mock.invocationCallOrder[0]! + expect(setHandlerOrder).toBeLessThan(connectOrder) + }) + + it('does not register elicitation handler when no callback is provided', async () => { + await client.connect() + + expect(sdkClientMock.setRequestHandler).not.toHaveBeenCalled() + }) + + it('passes elicitation capabilities to Client when callback is provided', () => { + const callback: ElicitationCallback = vi.fn() + new McpClient({ + applicationName: 'TestApp', + transport: mockTransport, + elicitationCallback: callback, + }) + + const lastCall = vi.mocked(Client).mock.calls.at(-1)! + expect(lastCall[1]).toEqual(expect.objectContaining({ capabilities: { elicitation: { form: {}, url: {} } } })) + }) + + it('elicitation handler returns accepted result with content', async () => { + const callbackResult = { action: 'accept' as const, content: { username: 'alice' } } + const callback: ElicitationCallback = vi.fn().mockResolvedValue(callbackResult) + const { handler } = await connectAndGetElicitationHandler(callback) + const request = { + method: 'elicitation/create', + params: { message: 'Enter username', requestedSchema: { type: 'object' } }, + } + const extra = { signal: new AbortController().signal } + + const result = await handler(request, extra) + + expect(callback).toHaveBeenCalledWith(extra, request.params) + expect(result).toEqual({ action: 'accept', content: { username: 'alice' } }) + }) + + it.each([{ action: 'decline' as const }, { action: 'cancel' as const }])( + 'elicitation handler returns $action result', + async (callbackResult) => { + const callback: ElicitationCallback = vi.fn().mockResolvedValue(callbackResult) + const { handler } = await connectAndGetElicitationHandler(callback) + const request = { + method: 'elicitation/create', + params: { message: 'Enter username', requestedSchema: { type: 'object' } }, + } + const extra = { signal: new AbortController().signal } + + const result = await handler(request, extra) + + expect(callback).toHaveBeenCalledWith(extra, request.params) + expect(result).toEqual({ action: callbackResult.action }) + } + ) + + it('elicitation handler works for URL mode params', async () => { + const callbackResult = { action: 'accept' as const } + const callback: ElicitationCallback = vi.fn().mockResolvedValue(callbackResult) + const { handler } = await connectAndGetElicitationHandler(callback) + const request = { + method: 'elicitation/create', + params: { + mode: 'url', + message: 'Please authenticate', + url: 'https://example.com/auth', + elicitationId: 'elicit-123', + }, + } + const extra = { signal: new AbortController().signal } + + const result = await handler(request, extra) + + expect(callback).toHaveBeenCalledWith(extra, request.params) + expect(result).toEqual({ action: 'accept' }) + }) + + it('elicitation callback errors propagate', async () => { + const callback: ElicitationCallback = vi.fn().mockRejectedValue(new Error('User cancelled')) + const { handler } = await connectAndGetElicitationHandler(callback) + const request = { + method: 'elicitation/create', + params: { message: 'Confirm?' }, + } + const extra = { signal: new AbortController().signal } + + await expect(handler(request, extra)).rejects.toThrow('User cancelled') + }) + }) + + describe('tools list changed', () => { + let client: McpClient + let sdkClientMock: { + connect: ReturnType + close: ReturnType + listTools: ReturnType + callTool: ReturnType + setRequestHandler: ReturnType + setNotificationHandler: ReturnType + getServerCapabilities: ReturnType + getServerVersion: ReturnType + getInstructions: ReturnType + experimental: { tasks: { callToolStream: ReturnType } } + } + + beforeEach(() => { + client = new McpClient({ applicationName: 'TestApp', transport: mockTransport }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + sdkClientMock.connect.mockResolvedValue(undefined) + }) + + function triggerToolsChanged(): void { + const ctorCall = vi.mocked(Client).mock.calls.at(-1)! + ctorCall[1]!.listChanged!.tools!.onChanged(null, null) + } + + it('calls onToolsChanged with old names and new tools when list changes', async () => { + sdkClientMock.listTools.mockResolvedValue({ + tools: [{ name: 'tool_a', description: 'A', inputSchema: {} }], + }) + await client.listTools() + + const onToolsChanged = vi.fn() + client.onToolsChanged = onToolsChanged + + sdkClientMock.listTools.mockResolvedValue({ + tools: [ + { name: 'tool_a', description: 'A', inputSchema: {} }, + { name: 'tool_b', description: 'B', inputSchema: {} }, + ], + }) + + triggerToolsChanged() + await vi.waitFor(() => expect(onToolsChanged).toHaveBeenCalled()) + + expect(onToolsChanged).toHaveBeenCalledWith(['tool_a'], expect.any(Array)) + const newTools = onToolsChanged.mock.calls[0]![1] as McpTool[] + expect(newTools.map((t) => t.name)).toEqual(['tool_a', 'tool_b']) + }) + + it('updates registered tool names after each listTools call', async () => { + sdkClientMock.listTools.mockResolvedValue({ + tools: [ + { name: 'x', description: 'X', inputSchema: {} }, + { name: 'y', description: 'Y', inputSchema: {} }, + ], + }) + await client.listTools() + + const onToolsChanged = vi.fn() + client.onToolsChanged = onToolsChanged + + sdkClientMock.listTools.mockResolvedValue({ + tools: [{ name: 'z', description: 'Z', inputSchema: {} }], + }) + + triggerToolsChanged() + await vi.waitFor(() => expect(onToolsChanged).toHaveBeenCalled()) + + expect(onToolsChanged).toHaveBeenCalledWith(['x', 'y'], expect.any(Array)) + const newTools = onToolsChanged.mock.calls[0]![1] as McpTool[] + expect(newTools.map((t) => t.name)).toEqual(['z']) + }) + + it('does not throw when onToolsChanged is not set', async () => { + sdkClientMock.listTools.mockResolvedValue({ + tools: [{ name: 'tool_a', description: 'A', inputSchema: {} }], + }) + await client.listTools() + + sdkClientMock.listTools.mockResolvedValue({ + tools: [{ name: 'tool_b', description: 'B', inputSchema: {} }], + }) + + triggerToolsChanged() + await new Promise((r) => setTimeout(r, 0)) + }) + + it('logs warning and preserves registry when listTools fails during refresh', async () => { + sdkClientMock.listTools.mockResolvedValue({ + tools: [{ name: 'tool_a', description: 'A', inputSchema: {} }], + }) + await client.listTools() + + const onToolsChanged = vi.fn() + client.onToolsChanged = onToolsChanged + + sdkClientMock.listTools.mockRejectedValue(new Error('server disconnected')) + const warnSpy = vi.spyOn(logger, 'warn') + + triggerToolsChanged() + await vi.waitFor(() => expect(warnSpy).toHaveBeenCalled()) + + expect(onToolsChanged).not.toHaveBeenCalled() + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('failed to refresh tools')) + }) + + it('coalesces notifications received during an in-flight refresh into one extra refresh', async () => { + sdkClientMock.listTools.mockResolvedValue({ + tools: [{ name: 'tool_a', description: 'A', inputSchema: {} }], + }) + await client.listTools() + + const onToolsChanged = vi.fn() + client.onToolsChanged = onToolsChanged + + let resolveListTools: (value: unknown) => void + sdkClientMock.listTools.mockReturnValue(new Promise((r) => (resolveListTools = r))) + + triggerToolsChanged() + triggerToolsChanged() + triggerToolsChanged() + + resolveListTools!({ tools: [{ name: 'tool_b', description: 'B', inputSchema: {} }] }) + await vi.waitFor(() => expect(onToolsChanged).toHaveBeenCalledTimes(2)) + + expect(sdkClientMock.listTools).toHaveBeenCalledTimes(3) + }) + }) + + describe('McpTool', () => { + const mockClientWrapper = { callTool: vi.fn() } as unknown as McpClient + const tool = new McpTool({ + name: 'weather', + description: 'Get weather', + inputSchema: {}, + client: mockClientWrapper, + }) + + const toolContext: ToolContext = { + toolUse: { toolUseId: 'id-123', name: 'weather', input: { city: 'NYC' } }, + agent: { cancelSignal: new AbortController().signal } as LocalAgent, + invocationState: {}, + interrupt: () => { + throw new Error('interrupt not available in mock context') + }, + } + + it('forwards agent cancelSignal to callTool', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [{ type: 'text', text: 'ok' }], + }) + + await runTool(tool.stream(toolContext)) + + expect(mockClientWrapper.callTool).toHaveBeenCalledWith( + tool, + { city: 'NYC' }, + { + signal: toolContext.agent.cancelSignal, + } + ) + }) + + it('returns text results on success', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [{ type: 'text', text: 'Sunny' }], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result).toBeDefined() + expect(result.status).toBe('success') + expect((result.content[0] as TextBlock).text).toBe('Sunny') + }) + + it('returns structured data results on success', async () => { + const data = { temperature: 72 } + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [{ type: 'data', value: data }], + }) + + const result = await runTool(tool.stream(toolContext)) + const content = result.content[0] as JsonBlock + + expect(content).toBeInstanceOf(JsonBlock) + expect(content.json).toEqual(expect.objectContaining({ value: data })) + }) + + it('provides default message for empty output', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ content: [] }) + + const result = await runTool(tool.stream(toolContext)) + + expect((result.content[0] as TextBlock).text).toContain('completed successfully') + }) + + it('handles protocol-level errors', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + isError: true, + content: [{ type: 'text', text: 'Service Unavailable' }], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toBe('Service Unavailable') + }) + + it('catches and wraps client exceptions', async () => { + vi.mocked(mockClientWrapper.callTool).mockRejectedValue(new Error('Network Error')) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toBe('Network Error') + }) + + it('validates SDK response format', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ content: null }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toContain('missing content array') + }) + + it('maps MCP image content to ImageBlock', async () => { + // "iVBOR..." is a minimal base64 PNG prefix + const base64Data = 'iVBORw0KGgoAAAANSUhEUg==' + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [{ type: 'image', data: base64Data, mimeType: 'image/png' }], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('success') + expect(result.content).toHaveLength(1) + const imageBlock = result.content[0] as ImageBlock + expect(imageBlock).toBeInstanceOf(ImageBlock) + expect(imageBlock.format).toBe('png') + expect(imageBlock.source.type).toBe('imageSourceBytes') + }) + + it('falls back to JsonBlock for unsupported image mime type', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [{ type: 'image', data: 'abc123', mimeType: 'image/bmp' }], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.content[0]).toBeInstanceOf(JsonBlock) + }) + + it('falls back to JsonBlock for image content missing data', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [{ type: 'image', mimeType: 'image/png' }], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.content[0]).toBeInstanceOf(JsonBlock) + }) + + it('maps MCP text resource to TextBlock', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [ + { type: 'resource', resource: { uri: 'file:///doc.txt', text: 'hello world', mimeType: 'text/plain' } }, + ], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('success') + expect((result.content[0] as TextBlock).text).toBe('hello world') + }) + + it('maps MCP blob resource with image mime type to ImageBlock', async () => { + const base64Data = 'iVBORw0KGgoAAAANSUhEUg==' + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [{ type: 'resource', resource: { uri: 'file:///img.png', blob: base64Data, mimeType: 'image/jpeg' } }], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.content[0]).toBeInstanceOf(ImageBlock) + expect((result.content[0] as ImageBlock).format).toBe('jpeg') + }) + + it('falls back to JsonBlock for blob resource with non-image mime type', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [ + { type: 'resource', resource: { uri: 'file:///doc.pdf', blob: 'abc123', mimeType: 'application/pdf' } }, + ], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.content[0]).toBeInstanceOf(JsonBlock) + }) + + it('falls back to JsonBlock for resource with neither text nor blob', async () => { + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [{ type: 'resource', resource: { uri: 'file:///unknown' } }], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.content[0]).toBeInstanceOf(JsonBlock) + }) + + it('handles mixed content types in a single result', async () => { + const base64Data = 'iVBORw0KGgoAAAANSUhEUg==' + vi.mocked(mockClientWrapper.callTool).mockResolvedValue({ + content: [ + { type: 'text', text: 'Here is the image:' }, + { type: 'image', data: base64Data, mimeType: 'image/png' }, + { type: 'resource', resource: { uri: 'file:///notes.txt', text: 'Some notes' } }, + ], + }) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.content).toHaveLength(3) + expect((result.content[0] as TextBlock).text).toBe('Here is the image:') + expect(result.content[1]).toBeInstanceOf(ImageBlock) + expect((result.content[2] as TextBlock).text).toBe('Some notes') + }) + + it('surfaces elicitation data for McpError with code -32042', async () => { + const elicitations = [ + { + mode: 'url', + message: 'Please authorize with GitHub', + elicitationId: 'e-123', + url: 'https://github.com/login/oauth/authorize?client_id=abc', + }, + ] + const mcpError = new McpError(ErrorCode.UrlElicitationRequired, 'Authorization required', { elicitations }) + vi.mocked(mockClientWrapper.callTool).mockRejectedValue(mcpError) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toBe( + `MCP Elicitation required: [${String(mcpError)}] with data ${JSON.stringify(elicitations)}` + ) + }) + + it('surfaces multiple elicitations for McpError with code -32042', async () => { + const elicitations = [ + { + mode: 'url', + message: 'Authorize with GitHub', + elicitationId: 'e-1', + url: 'https://github.com/login/oauth/authorize', + }, + { + mode: 'url', + message: 'Authorize with Google', + elicitationId: 'e-2', + url: 'https://accounts.google.com/o/oauth2/auth', + }, + ] + const mcpError = new McpError(ErrorCode.UrlElicitationRequired, 'Authorization required', { elicitations }) + vi.mocked(mockClientWrapper.callTool).mockRejectedValue(mcpError) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toBe( + `MCP Elicitation required: [${String(mcpError)}] with data ${JSON.stringify(elicitations)}` + ) + }) + + it('falls through to generic error for McpError -32042 with malformed data', async () => { + const mcpError = new McpError(ErrorCode.UrlElicitationRequired, 'Authorization required', { + unexpected: 'shape', + }) + vi.mocked(mockClientWrapper.callTool).mockRejectedValue(mcpError) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toBe('MCP error -32042: Authorization required') + }) + + it('surfaces elicitation data for UrlElicitationRequiredError', async () => { + const elicitations = [ + { + mode: 'url' as const, + message: 'Please authorize', + elicitationId: 'e-1', + url: 'https://example.com/auth', + }, + ] + const error = new UrlElicitationRequiredError(elicitations, 'Auth required') + vi.mocked(mockClientWrapper.callTool).mockRejectedValue(error) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toContain('MCP Elicitation required') + expect((result.content[0] as TextBlock).text).toContain('https://example.com/auth') + }) + + it('falls through to generic error for McpError -32042 with undefined data', async () => { + const mcpError = new McpError(ErrorCode.UrlElicitationRequired, 'Auth required') + vi.mocked(mockClientWrapper.callTool).mockRejectedValue(mcpError) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toBe('MCP error -32042: Auth required') + }) + + it('falls through to generic error for McpError -32042 with non-array elicitations', async () => { + const mcpError = new McpError(ErrorCode.UrlElicitationRequired, 'Auth required', { + elicitations: 'not-an-array', + }) + vi.mocked(mockClientWrapper.callTool).mockRejectedValue(mcpError) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toBe('MCP error -32042: Auth required') + }) + + it('falls through to generic error for McpError -32042 with empty elicitations', async () => { + const mcpError = new McpError(ErrorCode.UrlElicitationRequired, 'Auth required', { + elicitations: [], + }) + vi.mocked(mockClientWrapper.callTool).mockRejectedValue(mcpError) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toBe('MCP error -32042: Auth required') + }) + + it('falls through to generic error for McpError with a different code', async () => { + const mcpError = new McpError(ErrorCode.InvalidRequest, 'Bad request') + vi.mocked(mockClientWrapper.callTool).mockRejectedValue(mcpError) + + const result = await runTool(tool.stream(toolContext)) + + expect(result.status).toBe('error') + expect((result.content[0] as TextBlock).text).toBe('MCP error -32600: Bad request') + }) + }) +}) + +describe('server metadata getters', () => { + let client: McpClient + let sdkClientMock: { + connect: ReturnType + getServerCapabilities: ReturnType + getServerVersion: ReturnType + getInstructions: ReturnType + setNotificationHandler: ReturnType + setRequestHandler: ReturnType + } + + beforeEach(() => { + vi.clearAllMocks() + client = new McpClient({ applicationName: 'TestApp', transport: mockTransport }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + it('returns undefined for all getters before connect', () => { + sdkClientMock.getServerCapabilities.mockReturnValue(undefined) + sdkClientMock.getServerVersion.mockReturnValue(undefined) + sdkClientMock.getInstructions.mockReturnValue(undefined) + + expect(client.serverCapabilities).toBeUndefined() + expect(client.serverVersion).toBeUndefined() + expect(client.serverInstructions).toBeUndefined() + }) + + it('returns serverCapabilities after connect', async () => { + const caps = { tools: {} } + sdkClientMock.getServerCapabilities.mockReturnValue(caps) + + await client.connect() + + expect(client.serverCapabilities).toBe(caps) + }) + + it('returns serverVersion after connect', async () => { + const version = { name: 'my-server', version: '1.2.3' } + sdkClientMock.getServerVersion.mockReturnValue(version) + + await client.connect() + + expect(client.serverVersion).toBe(version) + }) + + it('returns serverInstructions after connect', async () => { + sdkClientMock.getInstructions.mockReturnValue('Use this server for X.') + + await client.connect() + + expect(client.serverInstructions).toBe('Use this server for X.') + }) + + it('connectionState is disconnected before connect', () => { + expect(client.connectionState).toBe('disconnected') + }) + + it('connectionState is connected after successful connect', async () => { + await client.connect() + expect(client.connectionState).toBe('connected') + }) +}) + +describe('continueOnError', () => { + let sdkClientMock: { + connect: ReturnType + listTools: ReturnType + callTool: ReturnType + setNotificationHandler: ReturnType + setRequestHandler: ReturnType + getServerCapabilities: ReturnType + getServerVersion: ReturnType + getInstructions: ReturnType + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + it('throws on connection failure by default', async () => { + const client = new McpClient({ applicationName: 'TestApp', transport: mockTransport }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + sdkClientMock.connect.mockRejectedValue(new Error('connection refused')) + + await expect(client.connect()).rejects.toThrow('connection refused') + }) + + it('swallows connection failure when continueOnError is true', async () => { + const client = new McpClient({ applicationName: 'TestApp', transport: mockTransport, continueOnError: true }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + sdkClientMock.connect.mockRejectedValue(new Error('connection refused')) + + await expect(client.connect()).resolves.toBeUndefined() + }) + + it('logs a warning when continueOnError swallows a connection failure', async () => { + const warnSpy = vi.spyOn(logger, 'warn').mockImplementation(() => {}) + const client = new McpClient({ applicationName: 'TestApp', transport: mockTransport, continueOnError: true }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + sdkClientMock.connect.mockRejectedValue(new Error('connection refused')) + + await client.connect() + + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('MCP server failed to connect')) + }) + + it('listTools returns empty array when continueOnError and connection failed', async () => { + const client = new McpClient({ applicationName: 'TestApp', transport: mockTransport, continueOnError: true }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + sdkClientMock.connect.mockRejectedValue(new Error('connection refused')) + + const tools = await client.listTools() + + expect(tools).toEqual([]) + }) + + it('callTool throws when continueOnError and connection failed', async () => { + const client = new McpClient({ applicationName: 'TestApp', transport: mockTransport, continueOnError: true }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + sdkClientMock.connect.mockRejectedValue(new Error('connection refused')) + const tool = new McpTool({ name: 'my_tool', description: '', inputSchema: {}, client }) + + await expect(client.callTool(tool, {})).rejects.toThrow( + 'MCP server failed to connect. Call connect(true) to retry.' + ) + }) + + it('does not retry connection on subsequent calls after continueOnError failure', async () => { + const client = new McpClient({ applicationName: 'TestApp', transport: mockTransport, continueOnError: true }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + sdkClientMock.connect.mockRejectedValue(new Error('connection refused')) + + await client.listTools() + await client.listTools() + + expect(sdkClientMock.connect).toHaveBeenCalledTimes(1) + }) + + it('recovers after explicit connect(true) when server comes back', async () => { + const client = new McpClient({ applicationName: 'TestApp', transport: mockTransport, continueOnError: true }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + sdkClientMock.connect.mockRejectedValueOnce(new Error('connection refused')) + sdkClientMock.listTools.mockResolvedValue({ tools: [] }) + + const firstTools = await client.listTools() + expect(firstTools).toEqual([]) + expect(client.connectionState).toBe('failed') + + await client.connect(true) + const secondTools = await client.listTools() + + expect(secondTools).toEqual([]) + expect(client.connectionState).toBe('connected') + expect(sdkClientMock.connect).toHaveBeenCalledTimes(2) + }) +}) + +describe('log routing', () => { + let notificationHandler: (notification: { params: LoggingMessageNotificationParams }) => void + let sdkClientMock: { + connect: ReturnType + setNotificationHandler: ReturnType + setRequestHandler: ReturnType + getServerCapabilities: ReturnType + getServerVersion: ReturnType + getInstructions: ReturnType + } + + beforeEach(() => { + vi.clearAllMocks() + new McpClient({ applicationName: 'TestApp', transport: mockTransport }) + sdkClientMock = vi.mocked(Client).mock.results.at(-1)!.value + // Handler is registered in the constructor — read it from the first setNotificationHandler call + notificationHandler = sdkClientMock.setNotificationHandler.mock.calls[0]![1] + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + it('routes debug level to logger.debug', () => { + const spy = vi.spyOn(logger, 'debug').mockImplementation(() => {}) + notificationHandler({ params: { level: 'debug', data: 'hello' } }) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('hello')) + }) + + it('routes info level to logger.info', () => { + const spy = vi.spyOn(logger, 'info').mockImplementation(() => {}) + notificationHandler({ params: { level: 'info', data: 'hello' } }) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('hello')) + }) + + it('routes notice level to logger.info', () => { + const spy = vi.spyOn(logger, 'info').mockImplementation(() => {}) + notificationHandler({ params: { level: 'notice', data: 'hello' } }) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('hello')) + }) + + it('routes warning level to logger.warn', () => { + const spy = vi.spyOn(logger, 'warn').mockImplementation(() => {}) + notificationHandler({ params: { level: 'warning', data: 'hello' } }) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('hello')) + }) + + it('routes error level to logger.error', () => { + const spy = vi.spyOn(logger, 'error').mockImplementation(() => {}) + notificationHandler({ params: { level: 'error', data: 'hello' } }) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('hello')) + }) + + it('routes critical level to logger.error', () => { + const spy = vi.spyOn(logger, 'error').mockImplementation(() => {}) + notificationHandler({ params: { level: 'critical', data: 'hello' } }) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('hello')) + }) + + it('routes alert level to logger.error', () => { + const spy = vi.spyOn(logger, 'error').mockImplementation(() => {}) + notificationHandler({ params: { level: 'alert', data: 'hello' } }) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('hello')) + }) + + it('routes emergency level to logger.error', () => { + const spy = vi.spyOn(logger, 'error').mockImplementation(() => {}) + notificationHandler({ params: { level: 'emergency', data: 'hello' } }) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('hello')) + }) + + it('includes logger name and data in the message', () => { + const spy = vi.spyOn(logger, 'info').mockImplementation(() => {}) + notificationHandler({ params: { level: 'info', logger: 'my-server', data: { key: 'val' } } }) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('my-server')) + expect(spy).toHaveBeenCalledWith(expect.stringContaining('key')) + }) + + it('calls custom logHandler when provided', () => { + const customHandler = vi.fn() + new McpClient({ applicationName: 'TestApp', transport: mockTransport, logHandler: customHandler }) + const customSdkMock = vi.mocked(Client).mock.results.at(-1)!.value + const capturedHandler = customSdkMock.setNotificationHandler.mock.calls[0]![1] + + const params: LoggingMessageNotificationParams = { level: 'info', data: 'test' } + capturedHandler({ params }) + + expect(customHandler).toHaveBeenCalledWith(params) + }) +}) + +describe('McpClient transport resolution', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('constructs StreamableHTTPClientTransport when url is provided', () => { + new McpClient({ url: 'https://mcp.example.com' }) + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(new URL('https://mcp.example.com'), {}) + }) + + it('constructs ClientCredentialsProvider when auth is provided', () => { + new McpClient({ url: 'https://mcp.example.com', auth: { clientId: 'id', clientSecret: 'secret' } }) + expect(ClientCredentialsProvider).toHaveBeenCalledWith({ clientId: 'id', clientSecret: 'secret' }) + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(new URL('https://mcp.example.com'), { + authProvider: expect.anything(), + }) + }) + + it('passes scopes as space-separated string', () => { + new McpClient({ + url: 'https://mcp.example.com', + auth: { clientId: 'id', clientSecret: 'secret', scopes: ['read', 'write'] }, + }) + expect(ClientCredentialsProvider).toHaveBeenCalledWith({ + clientId: 'id', + clientSecret: 'secret', + scope: 'read write', + }) + }) + + it('passes custom authProvider to transport', () => { + const customProvider = { redirectUrl: undefined, clientMetadata: {} } as never + new McpClient({ url: 'https://mcp.example.com', authProvider: customProvider }) + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(new URL('https://mcp.example.com'), { + authProvider: customProvider, + }) + }) + + it('throws when both transport and url are provided', () => { + expect(() => new McpClient({ transport: mockTransport, url: 'https://mcp.example.com' } as never)).toThrow( + 'provide either "transport" or "url", not both' + ) + }) + + it('throws when neither transport nor url is provided', () => { + expect(() => new McpClient({} as never)).toThrow('either "transport" or "url" must be provided') + }) + + it('throws when auth is provided with transport', () => { + expect( + () => new McpClient({ transport: mockTransport, auth: { clientId: 'x', clientSecret: 'y' } } as never) + ).toThrow('"auth", "authProvider", and "headers" require "url"') + }) + + it('throws when both auth and authProvider are provided', () => { + const customProvider = {} as never + expect( + () => + new McpClient({ + url: 'https://mcp.example.com', + auth: { clientId: 'x', clientSecret: 'y' }, + authProvider: customProvider, + } as never) + ).toThrow('provide either "auth" or "authProvider", not both') + }) + + it('accepts URL instance for url field', () => { + const url = new URL('https://mcp.example.com/path') + new McpClient({ url }) + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(url, {}) + }) + + it('passes headers as requestInit to transport', () => { + new McpClient({ url: 'https://mcp.example.com', headers: { 'X-Api-Key': 'abc' } }) + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(new URL('https://mcp.example.com'), { + requestInit: { headers: { 'X-Api-Key': 'abc' } }, + }) + }) + + it('passes both auth and headers to transport', () => { + new McpClient({ + url: 'https://mcp.example.com', + auth: { clientId: 'id', clientSecret: 'secret' }, + headers: { 'X-Trace': '123' }, + }) + expect(ClientCredentialsProvider).toHaveBeenCalledWith({ clientId: 'id', clientSecret: 'secret' }) + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(new URL('https://mcp.example.com'), { + authProvider: expect.anything(), + requestInit: { headers: { 'X-Trace': '123' } }, + }) + }) + + it('throws when headers is provided with transport', () => { + expect(() => new McpClient({ transport: mockTransport, headers: { 'X-Foo': 'bar' } } as never)).toThrow( + '"auth", "authProvider", and "headers" require "url"' + ) + }) +}) diff --git a/strands-ts/src/__tests__/mime.test.ts b/strands-ts/src/__tests__/mime.test.ts new file mode 100644 index 0000000000..1e51c9645c --- /dev/null +++ b/strands-ts/src/__tests__/mime.test.ts @@ -0,0 +1,88 @@ +import { describe, it, expect } from 'vitest' +import { toMimeType, toMediaFormat } from '../mime.js' + +describe('toMimeType', () => { + it.each([ + ['png', 'image/png'], + ['jpg', 'image/jpeg'], + ['jpeg', 'image/jpeg'], + ['gif', 'image/gif'], + ['webp', 'image/webp'], + ['mkv', 'video/x-matroska'], + ['mov', 'video/quicktime'], + ['mp4', 'video/mp4'], + ['webm', 'video/webm'], + ['flv', 'video/x-flv'], + ['mpeg', 'video/mpeg'], + ['mpg', 'video/mpeg'], + ['wmv', 'video/x-ms-wmv'], + ['3gp', 'video/3gpp'], + ['pdf', 'application/pdf'], + ['csv', 'text/csv'], + ['doc', 'application/msword'], + ['docx', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'], + ['xls', 'application/vnd.ms-excel'], + ['xlsx', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'], + ['html', 'text/html'], + ['txt', 'text/plain'], + ['md', 'text/markdown'], + ['json', 'application/json'], + ['xml', 'application/xml'], + ])('converts %s to %s', (mediaFormat, mimeType) => { + expect(toMimeType(mediaFormat)).toBe(mimeType) + }) + + it('is case-insensitive', () => { + expect(toMimeType('PNG')).toBe('image/png') + expect(toMimeType('Mp4')).toBe('video/mp4') + expect(toMimeType('PDF')).toBe('application/pdf') + }) + + it('returns undefined for unknown formats', () => { + expect(toMimeType('unknown')).toBeUndefined() + expect(toMimeType('bmp')).toBeUndefined() + expect(toMimeType('')).toBeUndefined() + }) +}) + +describe('toMediaFormat', () => { + it.each([ + ['image/png', 'png'], + ['image/jpeg', 'jpeg'], + ['image/gif', 'gif'], + ['image/webp', 'webp'], + ['video/x-matroska', 'mkv'], + ['video/quicktime', 'mov'], + ['video/mp4', 'mp4'], + ['video/webm', 'webm'], + ['video/x-flv', 'flv'], + ['video/mpeg', 'mpeg'], + ['video/x-ms-wmv', 'wmv'], + ['video/3gpp', '3gp'], + ['application/pdf', 'pdf'], + ['text/csv', 'csv'], + ['application/msword', 'doc'], + ['application/vnd.openxmlformats-officedocument.wordprocessingml.document', 'docx'], + ['application/vnd.ms-excel', 'xls'], + ['application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'xlsx'], + ['text/html', 'html'], + ['text/plain', 'txt'], + ['text/markdown', 'md'], + ['application/json', 'json'], + ['application/xml', 'xml'], + ])('converts %s to %s', (mimeType, mediaFormat) => { + expect(toMediaFormat(mimeType)).toBe(mediaFormat) + }) + + it('is case-insensitive', () => { + expect(toMediaFormat('IMAGE/PNG')).toBe('png') + expect(toMediaFormat('Video/Mp4')).toBe('mp4') + expect(toMediaFormat('Application/PDF')).toBe('pdf') + }) + + it('returns undefined for unknown MIME types', () => { + expect(toMediaFormat('image/bmp')).toBeUndefined() + expect(toMediaFormat('application/octet-stream')).toBeUndefined() + expect(toMediaFormat('')).toBeUndefined() + }) +}) diff --git a/strands-ts/src/__tests__/state-store.test.ts b/strands-ts/src/__tests__/state-store.test.ts new file mode 100644 index 0000000000..00e5e3366e --- /dev/null +++ b/strands-ts/src/__tests__/state-store.test.ts @@ -0,0 +1,395 @@ +import { describe, expect, it } from 'vitest' +import { StateStore } from '../state-store.js' +import { + isStateSerializable, + loadStateFromJSONSymbol, + loadStateSerializable, + serializeStateSerializable, + stateToJSONSymbol, +} from '../types/serializable.js' + +describe('StateStore', () => { + describe('constructor', () => { + it('creates empty state when no initial state provided', () => { + const state = new StateStore() + expect(state.keys()).toEqual([]) + }) + + it('creates state with initial values', () => { + const state = new StateStore({ key1: 'value1', key2: 42 }) + expect(state.get('key1')).toBe('value1') + expect(state.get('key2')).toBe(42) + }) + + it('stores deep copy of initial state', () => { + const initial = { nested: { value: 'test' } } + const state = new StateStore(initial) + + // Mutate original + initial.nested.value = 'changed' + + // State should not be affected + expect(state.get('nested')).toEqual({ value: 'test' }) + }) + + it('throws error for function in initial state', () => { + const invalidState = { func: () => 'test', value: 'keep' } + expect(() => new StateStore(invalidState as never)).toThrow( + 'initialState.func contains a function which cannot be serialized' + ) + }) + + it('throws error for symbol in initial state', () => { + const sym = Symbol('test') + const invalidState = { sym, value: 'keep' } + expect(() => new StateStore(invalidState as never)).toThrow( + 'initialState.sym contains a symbol which cannot be serialized' + ) + }) + + it('throws error for undefined in initial state', () => { + const invalidState = { undef: undefined, value: 'keep' } + expect(() => new StateStore(invalidState as never)).toThrow( + 'initialState.undef is undefined which cannot be serialized' + ) + }) + + it('throws error for nested function in initial state', () => { + const invalidState = { nested: { func: () => 'test' } } + expect(() => new StateStore(invalidState as never)).toThrow( + 'initialState.nested.func contains a function which cannot be serialized' + ) + }) + + it('throws error for function in array in initial state', () => { + const invalidState = { arr: [1, () => 'test', 3] } + expect(() => new StateStore(invalidState as never)).toThrow( + 'initialState.arr[1] contains a function which cannot be serialized' + ) + }) + }) + + describe('get', () => { + it('throws error when key is null or undefined', () => { + const state = new StateStore() + expect(() => state.get(null as any)).toThrow('key is required') + expect(() => state.get(undefined as any)).toThrow('key is required') + }) + + it('returns undefined when key does not exist', () => { + const state = new StateStore() + expect(state.get('nonexistent')).toBeUndefined() + }) + + it('returns value when key exists', () => { + const state = new StateStore({ key1: 'value1' }) + expect(state.get('key1')).toBe('value1') + }) + + it('returns deep copy that cannot mutate stored state', () => { + const state = new StateStore({ nested: { value: 'test' } }) + const retrieved = state.get<{ nested: { value: string } }>('nested') + + // Mutate retrieved value + retrieved!.value = 'changed' + + // Stored state should not be affected + expect(state.get('nested')).toEqual({ value: 'test' }) + }) + + it('infers correct type with generic state interface', () => { + interface TestState { + user: { name: string; age: number } + count: number + items: string[] + } + + const state = new StateStore({ user: { name: 'John', age: 30 }, count: 5, items: ['a', 'b'] }) + + // Type inference tests + const user = state.get('user') + const count = state.get('count') + const items = state.get('items') + + expect(user).toEqual({ name: 'John', age: 30 }) + expect(count).toBe(5) + expect(items).toEqual(['a', 'b']) + }) + + it('returns undefined for non-existent key with typed interface', () => { + interface TestState { + existing: string + } + + const state = new StateStore({ existing: 'value' }) + const result = state.get('existing') + + expect(result).toBe('value') + + // Non-existent key + const state2 = new StateStore() + const missing = state2.get('existing') + + expect(missing).toBeUndefined() + + // @ts-expect-error properties not on the TestsState are an error + state2.get('not-really') + }) + }) + + describe('set', () => { + it('sets string value successfully', () => { + const state = new StateStore() + state.set('key1', 'value1') + expect(state.get('key1')).toBe('value1') + }) + + it('sets number value successfully', () => { + const state = new StateStore() + state.set('key1', 42) + expect(state.get('key1')).toBe(42) + }) + + it('sets boolean value successfully', () => { + const state = new StateStore() + state.set('key1', true) + expect(state.get('key1')).toBe(true) + }) + + it('sets null value successfully', () => { + const state = new StateStore() + state.set('key1', null) + expect(state.get('key1')).toBeNull() + }) + + it('sets object value successfully', () => { + const state = new StateStore() + state.set('key1', { nested: 'value' }) + expect(state.get('key1')).toEqual({ nested: 'value' }) + }) + + it('sets array value successfully', () => { + const state = new StateStore() + state.set('key1', [1, 2, 3]) + expect(state.get('key1')).toEqual([1, 2, 3]) + }) + + it('overwrites existing value', () => { + const state = new StateStore({ key1: 'old' }) + state.set('key1', 'new') + expect(state.get('key1')).toBe('new') + }) + + it('stores deep copy that cannot mutate stored state', () => { + const state = new StateStore() + const value = { nested: { value: 'test' } } + state.set('key1', value) + + // Mutate original + value.nested.value = 'changed' + + // Stored state should not be affected + expect(state.get('key1')).toEqual({ nested: { value: 'test' } }) + }) + + it('throws error for function in value', () => { + const state = new StateStore({ existing: 'value' }) + const obj = { func: () => 'test', value: 'keep' } + expect(() => state.set('key1', obj)).toThrow( + 'value for key "key1".func contains a function which cannot be serialized' + ) + }) + + it('throws error for symbol in value', () => { + const state = new StateStore() + const sym = Symbol('test') + expect(() => state.set('key1', { sym } as never)).toThrow( + 'value for key "key1".sym contains a symbol which cannot be serialized' + ) + }) + + it('throws error for nested function in value', () => { + const state = new StateStore() + const obj = { nested: { func: () => 'test' } } + expect(() => state.set('key1', obj)).toThrow( + 'value for key "key1".nested.func contains a function which cannot be serialized' + ) + }) + + it('throws error for function in array', () => { + const state = new StateStore() + const arr = [1, () => 'test', 3] + expect(() => state.set('key1', arr)).toThrow( + 'value for key "key1"[1] contains a function which cannot be serialized' + ) + }) + + it('throws error for top-level symbol values', () => { + const state = new StateStore() + expect(() => state.set('key1', Symbol('test'))).toThrow( + 'value for key "key1" contains a symbol which cannot be serialized' + ) + }) + + it('throws error for top-level undefined values', () => { + const state = new StateStore() + expect(() => state.set('key1', undefined)).toThrow('value for key "key1" is undefined which cannot be serialized') + }) + + it('accepts typed value with generic state interface', () => { + interface TestState { + user: { name: string; age: number } + count: number + } + + const state = new StateStore() + + state.set('user', { name: 'Alice', age: 25 }) + state.set('count', 10) + + expect(state.get('user')).toEqual({ name: 'Alice', age: 25 }) + expect(state.get('count')).toBe(10) + + // @ts-expect-error properties not on the TestsState are an error + state.set('not-really', 'nope') + }) + }) + + describe('delete', () => { + it('removes existing key', () => { + const state = new StateStore({ key1: 'value1', key2: 'value2' }) + state.delete('key1') + expect(state.get('key1')).toBeUndefined() + expect(state.get('key2')).toBe('value2') + }) + + it('does not throw error for non-existent key', () => { + const state = new StateStore() + expect(() => state.delete('nonexistent')).not.toThrow() + }) + + it('supports typed usage with generic state interface', () => { + interface TestState { + user: { name: string } + count: number + } + + const state = new StateStore({ user: { name: 'Alice' }, count: 5 }) + + // Typed delete + state.delete('user') + expect(state.get('user')).toBeUndefined() + expect(state.get('count')).toBe(5) + }) + }) + + describe('clear', () => { + it('removes all values', () => { + const state = new StateStore({ key1: 'value1', key2: 'value2' }) + state.clear() + expect(state.keys()).toEqual([]) + expect(state.get('key1')).toBeUndefined() + expect(state.get('key2')).toBeUndefined() + }) + + it('works on empty state', () => { + const state = new StateStore() + expect(() => state.clear()).not.toThrow() + expect(state.keys()).toEqual([]) + }) + }) + + describe('getAll', () => { + it('returns object with all state', () => { + const state = new StateStore({ key1: 'value1', key2: 42 }) + expect(state.getAll()).toEqual({ key1: 'value1', key2: 42 }) + }) + + it('returns empty object for empty state', () => { + const state = new StateStore() + expect(state.getAll()).toEqual({}) + }) + }) + + describe('keys', () => { + it('returns array of all keys', () => { + const state = new StateStore({ key1: 'value1', key2: 'value2' }) + expect(state.keys().sort()).toEqual(['key1', 'key2']) + }) + + it('returns empty array for empty state', () => { + const state = new StateStore() + expect(state.keys()).toEqual([]) + }) + + it('returns new array each time', () => { + const state = new StateStore({ key1: 'value1' }) + const keys1 = state.keys() + const keys2 = state.keys() + expect(keys1).not.toBe(keys2) + }) + }) + + describe('stateToJSONSymbol (via symbol)', () => { + it('returns deep copy of state', () => { + const state = new StateStore({ key1: 'value1', nested: { deep: true } }) + const json = state[stateToJSONSymbol]() + expect(json).toEqual({ key1: 'value1', nested: { deep: true } }) + }) + + it('can be accessed via serializeStateSerializable helper', () => { + const state = new StateStore({ key1: 'value1' }) + const json = serializeStateSerializable(state) + expect(json).toEqual({ key1: 'value1' }) + }) + }) + + describe('loadStateFromJSONSymbol (via symbol)', () => { + it('replaces state with json data', () => { + const state = new StateStore({ old: 'data' }) + state[loadStateFromJSONSymbol]({ new: 'data', count: 42 }) + expect(state.getAll()).toEqual({ new: 'data', count: 42 }) + }) + + it('clears state when given non-object', () => { + const state = new StateStore({ key: 'value' }) + state[loadStateFromJSONSymbol](null) + expect(state.getAll()).toEqual({}) + }) + + it('can be accessed via loadStateSerializable helper', () => { + const state = new StateStore({ old: 'data' }) + loadStateSerializable(state, { new: 'data' }) + expect(state.getAll()).toEqual({ new: 'data' }) + }) + }) + + describe('isStateSerializable', () => { + it('returns true for StateStore instances', () => { + const state = new StateStore() + expect(isStateSerializable(state)).toBe(true) + }) + + it('returns false for plain objects', () => { + const obj = { toJSON: () => ({}), loadStateFromJson: () => {} } + expect(isStateSerializable(obj)).toBe(false) + }) + + it('returns false for null', () => { + expect(isStateSerializable(null)).toBe(false) + }) + + it('returns false for objects with only one symbol method', () => { + const partial = { [stateToJSONSymbol]: () => ({}) } + expect(isStateSerializable(partial)).toBe(false) + }) + + it('returns true for objects implementing both symbol methods', () => { + const custom = { + [stateToJSONSymbol]: () => ({ custom: true }), + [loadStateFromJSONSymbol]: () => {}, + } + expect(isStateSerializable(custom)).toBe(true) + }) + }) +}) diff --git a/strands-ts/src/a2a/__tests__/a2a-agent.test.ts b/strands-ts/src/a2a/__tests__/a2a-agent.test.ts new file mode 100644 index 0000000000..4414ff6886 --- /dev/null +++ b/strands-ts/src/a2a/__tests__/a2a-agent.test.ts @@ -0,0 +1,462 @@ +import { describe, expect, it, vi, beforeEach } from 'vitest' +import { A2AAgent } from '../a2a-agent.js' +import { A2AStreamUpdateEvent, A2AResultEvent } from '../events.js' +import type { + AgentCard, + Task, + Message as A2AMessage, + TaskArtifactUpdateEvent, + TaskStatusUpdateEvent, +} from '@a2a-js/sdk' +import { TextBlock, Message } from '../../types/messages.js' +import type { InvokeArgs } from '../../types/agent.js' + +// Mock the A2A SDK client +const mockSendMessageStream = vi.fn() +const mockGetAgentCard = vi.fn() + +vi.mock('@a2a-js/sdk/client', () => ({ + ClientFactory: class MockClientFactory { + async createFromUrl(): Promise<{ + sendMessageStream: typeof mockSendMessageStream + getAgentCard: typeof mockGetAgentCard + }> { + return { + sendMessageStream: mockSendMessageStream, + getAgentCard: mockGetAgentCard, + } + } + }, +})) + +const mockAgentCard: AgentCard = { + name: 'Remote Agent', + description: 'A remote agent for testing', + version: '1.0.0', + protocolVersion: '0.2.0', + url: 'http://localhost:9000', + defaultInputModes: ['text/plain'], + defaultOutputModes: ['text/plain'], + skills: [], + capabilities: {}, +} + +function createMockTaskResponse(): Task { + return { + kind: 'task', + id: 'task-1', + contextId: 'ctx-1', + status: { state: 'completed' }, + artifacts: [ + { + artifactId: 'art-1', + parts: [{ kind: 'text', text: 'Agent response' }], + }, + ], + } +} + +async function* mockStream(...events: unknown[]): AsyncGenerator { + for (const event of events) { + yield event + } +} + +async function collectStream( + gen: AsyncGenerator +): Promise<{ events: unknown[]; result: unknown }> { + const events: unknown[] = [] + let next = await gen.next() + while (!next.done) { + events.push(next.value) + next = await gen.next() + } + return { events, result: next.value } +} + +describe('A2AAgent', () => { + beforeEach(() => { + vi.clearAllMocks() + mockGetAgentCard.mockResolvedValue(mockAgentCard) + mockSendMessageStream.mockReturnValue(mockStream(createMockTaskResponse())) + }) + + describe('identity properties', () => { + it('defaults id to the URL when not provided', () => { + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + expect(agent.id).toBe('http://localhost:9000') + }) + + it('uses provided id from config', () => { + const agent = new A2AAgent({ url: 'http://localhost:9000', id: 'custom-id' }) + expect(agent.id).toBe('custom-id') + }) + + it('uses provided name and description from config', () => { + const agent = new A2AAgent({ url: 'http://localhost:9000', name: 'My Agent', description: 'Does things' }) + expect(agent.name).toBe('My Agent') + expect(agent.description).toBe('Does things') + }) + + it('has undefined name and description when not provided in config', () => { + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + expect(agent.name).toBeUndefined() + expect(agent.description).toBeUndefined() + }) + + it('populates name and description from agent card on first connection', async () => { + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + expect(agent.name).toBeUndefined() + expect(agent.description).toBeUndefined() + + await agent.invoke('Hello') + + expect(agent.name).toBe('Remote Agent') + expect(agent.description).toBe('A remote agent for testing') + }) + + it('does not overwrite config-provided name and description with agent card values', async () => { + const agent = new A2AAgent({ + url: 'http://localhost:9000', + name: 'Custom Name', + description: 'Custom description', + }) + + await agent.invoke('Hello') + + expect(agent.name).toBe('Custom Name') + expect(agent.description).toBe('Custom description') + }) + }) + + describe('invoke', () => { + it('returns AgentResult with response text', async () => { + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + + const result = await agent.invoke('Hello') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + expect(result.lastMessage.content).toHaveLength(1) + expect(result.lastMessage.content[0]).toBeInstanceOf(TextBlock) + expect((result.lastMessage.content[0] as TextBlock).text).toBe('Agent response') + }) + + it.each([ + { desc: 'string', args: 'Hello from string', expectedText: 'Hello from string' }, + { desc: 'ContentBlock[]', args: [new TextBlock('Hello from blocks')], expectedText: 'Hello from blocks' }, + { desc: 'ContentBlockData[]', args: [{ text: 'Hello from data' }], expectedText: 'Hello from data' }, + { + desc: 'multiple ContentBlocks joined with newline', + args: [new TextBlock('Line 1'), new TextBlock('Line 2')], + expectedText: 'Line 1\nLine 2', + }, + { + desc: 'Message[] (last user message)', + args: [ + new Message({ role: 'user', content: [new TextBlock('First')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response')] }), + new Message({ role: 'user', content: [new TextBlock('Second')] }), + ], + expectedText: 'Second', + }, + { + desc: 'MessageData[] (plain objects)', + args: [{ role: 'user', content: [{ text: 'From plain data' }] }], + expectedText: 'From plain data', + }, + { + desc: 'Message[] with no user messages', + args: [new Message({ role: 'assistant', content: [new TextBlock('No user')] })], + expectedText: '', + }, + { desc: 'empty array', args: [] as TextBlock[], expectedText: '' }, + ])('sends correct parts for $desc input', async ({ args, expectedText }) => { + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + + await agent.invoke(args as InvokeArgs) + + expect(mockSendMessageStream).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.objectContaining({ + parts: [{ kind: 'text', text: expectedText }], + }), + }) + ) + }) + + it('auto-connects on first invoke', async () => { + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + await agent.invoke('Hello') + expect(mockGetAgentCard).toHaveBeenCalledOnce() + }) + + it('uses custom clientFactory when provided', async () => { + const customSendMessageStream = vi.fn().mockReturnValue(mockStream(createMockTaskResponse())) + const customGetAgentCard = vi.fn().mockResolvedValue(mockAgentCard) + const customCreateFromUrl = vi.fn().mockResolvedValue({ + sendMessageStream: customSendMessageStream, + getAgentCard: customGetAgentCard, + }) + const customFactory = { createFromUrl: customCreateFromUrl } + + const agent = new A2AAgent({ + url: 'http://localhost:9000', + clientFactory: customFactory as never, + }) + + await agent.invoke('Hello') + + expect(customCreateFromUrl).toHaveBeenCalledWith('http://localhost:9000', undefined) + expect(customGetAgentCard).toHaveBeenCalledOnce() + expect(customSendMessageStream).toHaveBeenCalledOnce() + // Default mock should not have been called + expect(mockSendMessageStream).not.toHaveBeenCalled() + }) + }) + + describe('stream', () => { + it('yields A2AStreamUpdateEvent for each A2A event and A2AResultEvent at the end', async () => { + const task = createMockTaskResponse() + mockSendMessageStream.mockReturnValue(mockStream(task)) + + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const { events, result } = await collectStream(agent.stream('Hello')) + + expect(events).toHaveLength(2) + expect(events[0]).toBeInstanceOf(A2AStreamUpdateEvent) + expect((events[0] as A2AStreamUpdateEvent).event).toStrictEqual(task) + expect(events[1]).toBeInstanceOf(A2AResultEvent) + expect((result as { stopReason: string }).stopReason).toBe('endTurn') + }) + + it('yields multiple A2AStreamUpdateEvents for streamed artifact chunks', async () => { + const artifactUpdate1: TaskArtifactUpdateEvent = { + kind: 'artifact-update', + taskId: 'task-1', + contextId: 'ctx-1', + artifact: { artifactId: 'art-1', parts: [{ kind: 'text', text: 'Hello ' }] }, + append: false, + } + const artifactUpdate2: TaskArtifactUpdateEvent = { + kind: 'artifact-update', + taskId: 'task-1', + contextId: 'ctx-1', + artifact: { artifactId: 'art-1', parts: [{ kind: 'text', text: 'World' }] }, + append: true, + lastChunk: true, + } + const statusUpdate: TaskStatusUpdateEvent = { + kind: 'status-update', + taskId: 'task-1', + contextId: 'ctx-1', + status: { + state: 'completed', + message: { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Final answer' }], + }, + }, + final: true, + } + + mockSendMessageStream.mockReturnValue(mockStream(artifactUpdate1, artifactUpdate2, statusUpdate)) + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const { events } = await collectStream(agent.stream('Hello')) + + // 3 A2AStreamUpdateEvents + 1 A2AResultEvent + expect(events).toHaveLength(4) + expect(events[0]).toBeInstanceOf(A2AStreamUpdateEvent) + expect(events[1]).toBeInstanceOf(A2AStreamUpdateEvent) + expect(events[2]).toBeInstanceOf(A2AStreamUpdateEvent) + expect(events[3]).toBeInstanceOf(A2AResultEvent) + + // Final result built from last complete event (status-update with completed state) + const resultEvent = events[3] as A2AResultEvent + expect((resultEvent.result.lastMessage.content[0] as TextBlock).text).toBe('Final answer') + }) + + it('yields A2AStreamUpdateEvent for Message response', async () => { + const message: A2AMessage = { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Direct response' }], + } + mockSendMessageStream.mockReturnValue(mockStream(message)) + + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const { events } = await collectStream(agent.stream('Hello')) + + expect(events).toHaveLength(2) + expect(events[0]).toBeInstanceOf(A2AStreamUpdateEvent) + expect((events[0] as A2AStreamUpdateEvent).event.kind).toBe('message') + + const resultEvent = events[1] as A2AResultEvent + expect((resultEvent.result.lastMessage.content[0] as TextBlock).text).toBe('Direct response') + }) + + it('builds result from status-update with completed state', async () => { + const statusUpdate: TaskStatusUpdateEvent = { + kind: 'status-update', + taskId: 'task-1', + contextId: 'ctx-1', + status: { + state: 'completed', + message: { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Status text' }], + }, + }, + final: true, + } + mockSendMessageStream.mockReturnValue(mockStream(statusUpdate)) + + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const { events } = await collectStream(agent.stream('Hello')) + + const resultEvent = events[1] as A2AResultEvent + expect((resultEvent.result.lastMessage.content[0] as TextBlock).text).toBe('Status text') + }) + + it('returns empty text when no events are received', async () => { + mockSendMessageStream.mockReturnValue(mockStream()) + + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const { events, result } = await collectStream(agent.stream('Hello')) + + expect(events).toHaveLength(1) // only A2AResultEvent + expect(events[0]).toBeInstanceOf(A2AResultEvent) + expect((result as { lastMessage: Message }).lastMessage.content[0]).toBeInstanceOf(TextBlock) + expect(((result as { lastMessage: Message }).lastMessage.content[0] as TextBlock).text).toBe('') + }) + }) + + describe('response extraction', () => { + it('extracts text from Task response', async () => { + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const result = await agent.invoke('Hello') + expect((result.lastMessage.content[0] as TextBlock).text).toBe('Agent response') + }) + + it('extracts text from Message response', async () => { + mockSendMessageStream.mockReturnValue( + mockStream({ + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Direct response' }], + }) + ) + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const result = await agent.invoke('Hello') + expect((result.lastMessage.content[0] as TextBlock).text).toBe('Direct response') + }) + }) +}) + +describe('response text extraction via invoke', () => { + beforeEach(() => { + vi.clearAllMocks() + mockGetAgentCard.mockResolvedValue(mockAgentCard) + }) + + it('joins multiple text parts from Task artifacts', async () => { + mockSendMessageStream.mockReturnValue( + mockStream({ + kind: 'task', + id: 'task-1', + contextId: 'ctx-1', + status: { state: 'completed' }, + artifacts: [ + { + artifactId: 'art-1', + parts: [ + { kind: 'text', text: 'Part 1' }, + { kind: 'text', text: 'Part 2' }, + ], + }, + ], + } as Task) + ) + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const result = await agent.invoke('Hello') + expect((result.lastMessage.content[0] as TextBlock).text).toBe('Part 1\nPart 2') + }) + + it('joins text from multiple Task artifacts', async () => { + mockSendMessageStream.mockReturnValue( + mockStream({ + kind: 'task', + id: 'task-1', + contextId: 'ctx-1', + status: { state: 'completed' }, + artifacts: [ + { artifactId: 'art-1', parts: [{ kind: 'text', text: 'First' }] }, + { artifactId: 'art-2', parts: [{ kind: 'text', text: 'Second' }] }, + ], + } as Task) + ) + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const result = await agent.invoke('Hello') + expect((result.lastMessage.content[0] as TextBlock).text).toBe('First\nSecond') + }) + + it('falls back to Task status message when no artifacts', async () => { + mockSendMessageStream.mockReturnValue( + mockStream({ + kind: 'task', + id: 'task-1', + contextId: 'ctx-1', + status: { + state: 'completed', + message: { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Status text' }], + }, + }, + } as Task) + ) + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const result = await agent.invoke('Hello') + expect((result.lastMessage.content[0] as TextBlock).text).toBe('Status text') + }) + + it('returns empty text for Task with no text content', async () => { + mockSendMessageStream.mockReturnValue( + mockStream({ + kind: 'task', + id: 'task-1', + contextId: 'ctx-1', + status: { state: 'completed' }, + } as Task) + ) + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const result = await agent.invoke('Hello') + expect((result.lastMessage.content[0] as TextBlock).text).toBe('') + }) + + it('extracts text from Message parts, ignoring non-text parts', async () => { + mockSendMessageStream.mockReturnValue( + mockStream({ + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [ + { kind: 'text', text: 'Hello' }, + { kind: 'file', file: { uri: 'file://test.txt' } }, + { kind: 'text', text: 'World' }, + ], + } as A2AMessage) + ) + const agent = new A2AAgent({ url: 'http://localhost:9000' }) + const result = await agent.invoke('Hello') + expect((result.lastMessage.content[0] as TextBlock).text).toBe('Hello\nWorld') + }) +}) diff --git a/strands-ts/src/a2a/__tests__/adapters.test.ts b/strands-ts/src/a2a/__tests__/adapters.test.ts new file mode 100644 index 0000000000..28242ba843 --- /dev/null +++ b/strands-ts/src/a2a/__tests__/adapters.test.ts @@ -0,0 +1,179 @@ +import { describe, expect, it } from 'vitest' +import { partsToContentBlocks, contentBlocksToParts } from '../adapters.js' +import { TextBlock, ToolUseBlock, ReasoningBlock } from '../../types/messages.js' +import type { ContentBlock } from '../../types/messages.js' +import { ImageBlock, VideoBlock, DocumentBlock, encodeBase64 } from '../../types/media.js' +import type { Part } from '@a2a-js/sdk' + +describe('adapters', () => { + describe('partsToContentBlocks', () => { + it('converts text parts to TextBlocks', () => { + const parts: Part[] = [ + { kind: 'text', text: 'Hello' }, + { kind: 'text', text: 'World' }, + ] + + const blocks = partsToContentBlocks(parts) + + expect(blocks).toHaveLength(2) + expect(blocks[0]).toBeInstanceOf(TextBlock) + expect((blocks[0] as TextBlock).text).toBe('Hello') + expect((blocks[1] as TextBlock).text).toBe('World') + }) + + it.each([ + { mimeType: 'image/png', BlockClass: ImageBlock, format: 'png' }, + { mimeType: 'image/jpeg', BlockClass: ImageBlock, format: 'jpeg' }, + { mimeType: 'video/mp4', BlockClass: VideoBlock, format: 'mp4' }, + { mimeType: 'application/pdf', BlockClass: DocumentBlock, format: 'pdf' }, + { mimeType: 'application/vnd.ms-excel', BlockClass: DocumentBlock, format: 'xls' }, + { mimeType: 'application/octet-stream', BlockClass: DocumentBlock, format: 'octet-stream' }, + ])( + 'converts file with bytes and MIME $mimeType to correct block with format $format', + ({ mimeType, BlockClass, format }) => { + const parts: Part[] = [{ kind: 'file', file: { bytes: encodeBase64('fake-data'), mimeType, name: 'test' } }] + + const blocks = partsToContentBlocks(parts) + + expect(blocks).toHaveLength(1) + expect(blocks[0]).toBeInstanceOf(BlockClass) + expect((blocks[0] as ImageBlock | VideoBlock | DocumentBlock).format).toBe(format) + } + ) + + it.each([ + { + desc: 'with name', + file: { uri: 'https://example.com/file.txt', name: 'readme.txt' }, + expected: '[File: readme.txt (https://example.com/file.txt)]', + }, + { + desc: 'without name (defaults to "file")', + file: { uri: 'https://example.com/file.txt' }, + expected: '[File: file (https://example.com/file.txt)]', + }, + ])('converts file with URI to TextBlock — $desc', ({ file, expected }) => { + const blocks = partsToContentBlocks([{ kind: 'file', file }]) + + expect(blocks).toHaveLength(1) + expect(blocks[0]).toBeInstanceOf(TextBlock) + expect((blocks[0] as TextBlock).text).toBe(expected) + }) + + it('converts data parts to TextBlock with JSON', () => { + const parts: Part[] = [{ kind: 'data', data: { key: 'value', count: 42 } }] + + const blocks = partsToContentBlocks(parts) + + expect(blocks).toHaveLength(1) + expect(blocks[0]).toBeInstanceOf(TextBlock) + const text = (blocks[0] as TextBlock).text + expect(text).toContain('[Structured Data]') + expect(text).toContain('"key": "value"') + }) + + it('handles mixed part types', () => { + const parts: Part[] = [ + { kind: 'text', text: 'Hello' }, + { kind: 'file', file: { uri: 'file://test.txt' } }, + { kind: 'data', data: { foo: 'bar' } }, + ] + + const blocks = partsToContentBlocks(parts) + + expect(blocks).toHaveLength(3) + expect(blocks[0]).toBeInstanceOf(TextBlock) + expect(blocks[1]).toBeInstanceOf(TextBlock) // URI file → text fallback + expect(blocks[2]).toBeInstanceOf(TextBlock) // data → text + }) + + it('returns empty array for empty input', () => { + expect(partsToContentBlocks([])).toStrictEqual([]) + }) + }) + + describe('contentBlocksToParts', () => { + it('converts text blocks to text parts', () => { + const blocks: ContentBlock[] = [new TextBlock('Hello'), new TextBlock('World')] + + expect(contentBlocksToParts(blocks)).toStrictEqual([ + { kind: 'text', text: 'Hello' }, + { kind: 'text', text: 'World' }, + ]) + }) + + it.each([ + { + desc: 'ImageBlock with bytes', + block: new ImageBlock({ format: 'png', source: { bytes: new Uint8Array([137, 80, 78, 71]) } }), + expected: { + kind: 'file', + file: { bytes: encodeBase64(new Uint8Array([137, 80, 78, 71])), mimeType: 'image/png' }, + }, + }, + { + desc: 'ImageBlock with URL', + block: new ImageBlock({ format: 'jpeg', source: { url: 'https://example.com/img.jpg' } }), + expected: { kind: 'file', file: { uri: 'https://example.com/img.jpg', mimeType: 'image/jpeg' } }, + }, + { + desc: 'VideoBlock with bytes', + block: new VideoBlock({ format: 'mp4', source: { bytes: new Uint8Array([0, 0, 0]) } }), + expected: { kind: 'file', file: { bytes: encodeBase64(new Uint8Array([0, 0, 0])), mimeType: 'video/mp4' } }, + }, + { + desc: 'DocumentBlock with bytes', + block: new DocumentBlock({ name: 'doc.pdf', format: 'pdf', source: { bytes: new Uint8Array([37, 80]) } }), + expected: { + kind: 'file', + file: { bytes: encodeBase64(new Uint8Array([37, 80])), mimeType: 'application/pdf', name: 'doc.pdf' }, + }, + }, + { + desc: 'DocumentBlock with text source', + block: new DocumentBlock({ name: 'readme', format: 'txt', source: { text: 'Hello doc' } }), + expected: { kind: 'text', text: 'Hello doc' }, + }, + ])('converts $desc to file part', ({ block, expected }) => { + expect(contentBlocksToParts([block])).toStrictEqual([expected]) + }) + + it('handles mixed text and media blocks', () => { + const blocks: ContentBlock[] = [ + new TextBlock('Caption'), + new ImageBlock({ format: 'png', source: { bytes: new Uint8Array([1, 2]) } }), + new TextBlock('End'), + ] + + const parts = contentBlocksToParts(blocks) + + expect(parts).toHaveLength(3) + expect(parts[0]).toStrictEqual({ kind: 'text', text: 'Caption' }) + expect(parts[1]).toStrictEqual({ + kind: 'file', + file: { bytes: encodeBase64(new Uint8Array([1, 2])), mimeType: 'image/png' }, + }) + expect(parts[2]).toStrictEqual({ kind: 'text', text: 'End' }) + }) + + it('skips unsupported block types', () => { + const blocks: ContentBlock[] = [ + new TextBlock('Hello'), + new ToolUseBlock({ name: 'test', toolUseId: 'id-1', input: {} }), + new ReasoningBlock({ text: 'thinking' }), + ] + + expect(contentBlocksToParts(blocks)).toStrictEqual([{ kind: 'text', text: 'Hello' }]) + }) + + it.each([ + { desc: 'empty input', blocks: [] as ContentBlock[] }, + { + desc: 'no convertible blocks', + blocks: [new ToolUseBlock({ name: 'test', toolUseId: 'id-1', input: {} })] as ContentBlock[], + }, + ])('returns empty array for $desc', ({ blocks }) => { + expect(contentBlocksToParts(blocks)).toStrictEqual([]) + }) + }) +}) diff --git a/strands-ts/src/a2a/__tests__/events.test.ts b/strands-ts/src/a2a/__tests__/events.test.ts new file mode 100644 index 0000000000..42bc2f7b6b --- /dev/null +++ b/strands-ts/src/a2a/__tests__/events.test.ts @@ -0,0 +1,78 @@ +import { describe, expect, it } from 'vitest' +import { A2AStreamUpdateEvent, A2AResultEvent } from '../events.js' +import { AgentResult } from '../../types/agent.js' +import { Message, TextBlock } from '../../types/messages.js' +import { AgentMetrics } from '../../telemetry/meter.js' +import type { A2AEventData } from '../events.js' + +describe('A2AStreamUpdateEvent', () => { + it('creates instance with correct properties', () => { + const eventData = { kind: 'status-update', taskId: 'task-1', status: { state: 'working' } } as A2AEventData + const event = new A2AStreamUpdateEvent(eventData) + + expect(event.type).toBe('a2aStreamUpdateEvent') + expect(event.event).toBe(eventData) + }) + + describe('toJSON', () => { + const event = new A2AStreamUpdateEvent({ + kind: 'status-update', + taskId: 'task-1', + status: { state: 'working' }, + } as A2AEventData) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'a2aStreamUpdateEvent', + event: { kind: 'status-update', taskId: 'task-1', status: { state: 'working' } }, + }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual([]) + }) + }) +}) + +describe('A2AResultEvent', () => { + it('creates instance with correct properties', () => { + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: new Message({ role: 'assistant', content: [new TextBlock('Done')] }), + metrics: new AgentMetrics(), + invocationState: {}, + }) + const event = new A2AResultEvent({ result }) + + expect(event.type).toBe('a2aResultEvent') + expect(event.result).toBe(result) + }) + + describe('toJSON', () => { + const event = new A2AResultEvent({ + result: new AgentResult({ + stopReason: 'endTurn', + lastMessage: new Message({ role: 'assistant', content: [new TextBlock('Done')] }), + metrics: new AgentMetrics(), + invocationState: {}, + }), + }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'a2aResultEvent', + result: { + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: { role: 'assistant', content: [{ text: 'Done' }] }, + }, + }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual([]) + }) + }) +}) diff --git a/strands-ts/src/a2a/__tests__/executor.test.ts b/strands-ts/src/a2a/__tests__/executor.test.ts new file mode 100644 index 0000000000..a20aba87ba --- /dev/null +++ b/strands-ts/src/a2a/__tests__/executor.test.ts @@ -0,0 +1,252 @@ +import { describe, expect, it, vi } from 'vitest' +import { A2AExecutor } from '../executor.js' +import type { AgentExecutionEvent, ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server' +import type { TaskArtifactUpdateEvent, TaskStatusUpdateEvent } from '@a2a-js/sdk' +import { Agent } from '../../agent/agent.js' +import type { InvokableAgent } from '../../types/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { createMockAgent } from '../../__fixtures__/agent-helpers.js' +import { TextBlock } from '../../types/messages.js' +import { ImageBlock, encodeBase64 } from '../../types/media.js' +import { ContentBlockEvent, ModelStreamUpdateEvent } from '../../hooks/events.js' +import { AgentResult } from '../../types/agent.js' +import { Message } from '../../types/messages.js' + +function createMockEventBus(): ExecutionEventBus & { events: AgentExecutionEvent[] } { + const events: AgentExecutionEvent[] = [] + return { + events, + publish: vi.fn((event) => { + events.push(event) + }), + on: vi.fn().mockReturnThis(), + off: vi.fn().mockReturnThis(), + once: vi.fn().mockReturnThis(), + removeAllListeners: vi.fn().mockReturnThis(), + finished: vi.fn(), + } +} + +function createRequestContext(text: string, taskId: string = 'task-1'): RequestContext { + return { + taskId, + contextId: 'ctx-1', + userMessage: { + kind: 'message', + messageId: 'msg-1', + role: 'user', + parts: [{ kind: 'text', text }], + }, + } +} + +describe('A2AExecutor', () => { + describe('execute', () => { + it('streams text deltas as artifact chunks and publishes completed status', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Agent response' }) + const agent = new Agent({ model, printer: false }) + const executor = new A2AExecutor(agent) + const eventBus = createMockEventBus() + + await executor.execute(createRequestContext('Hello agent'), eventBus) + + // First event registers the task with the ResultManager + expect(eventBus.events[0]).toStrictEqual({ + kind: 'task', + id: 'task-1', + contextId: 'ctx-1', + status: { state: 'working' }, + }) + + const artifactEvents = eventBus.events.filter((e): e is TaskArtifactUpdateEvent => e.kind === 'artifact-update') + const statusEvents = eventBus.events.filter((e): e is TaskStatusUpdateEvent => e.kind === 'status-update') + + // Should have at least 2 artifact events (text delta + lastChunk) + expect(artifactEvents.length).toBeGreaterThanOrEqual(2) + + // First artifact: text delta, creates new artifact + expect(artifactEvents[0]).toStrictEqual({ + kind: 'artifact-update', + taskId: 'task-1', + contextId: 'ctx-1', + append: false, + artifact: { artifactId: expect.any(String), parts: [{ kind: 'text', text: 'Agent response' }] }, + }) + + // Last artifact: lastChunk marker, appends to existing artifact + expect(artifactEvents[artifactEvents.length - 1]).toStrictEqual( + expect.objectContaining({ append: true, lastChunk: true }) + ) + + // All artifact events share the same artifactId + const artifactId = artifactEvents[0]!.artifact.artifactId + for (const event of artifactEvents) { + expect(event.artifact.artifactId).toBe(artifactId) + } + + // Only completed status — no working status (A2A-compliant streaming) + expect(statusEvents).toStrictEqual([ + { kind: 'status-update', taskId: 'task-1', contextId: 'ctx-1', status: { state: 'completed' }, final: true }, + ]) + }) + + it('sets append to true for subsequent chunks after the first', async () => { + const model = new MockMessageModel().addTurn([ + { type: 'textBlock', text: 'First' }, + { type: 'textBlock', text: 'Second' }, + ]) + const agent = new Agent({ model, printer: false }) + const executor = new A2AExecutor(agent) + const eventBus = createMockEventBus() + + await executor.execute(createRequestContext('Hello'), eventBus) + + const artifactEvents = eventBus.events.filter((e): e is TaskArtifactUpdateEvent => e.kind === 'artifact-update') + + // 2 text deltas + 1 lastChunk + expect(artifactEvents).toHaveLength(3) + expect(artifactEvents.map((e) => e.append)).toStrictEqual([false, true, true]) + expect(artifactEvents[2]!.lastChunk).toBe(true) + }) + + it('converts A2A parts to content blocks and passes to stream', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model, printer: false }) + vi.spyOn(agent, 'stream') + const executor = new A2AExecutor(agent) + const eventBus = createMockEventBus() + + const context: RequestContext = { + taskId: 'task-1', + contextId: 'ctx-1', + userMessage: { + kind: 'message', + messageId: 'msg-1', + role: 'user', + parts: [ + { kind: 'text', text: 'Line 1' }, + { kind: 'file', file: { uri: 'file://test.txt' } }, + { kind: 'text', text: 'Line 2' }, + ], + }, + } + + await executor.execute(context, eventBus) + + expect(agent.stream).toHaveBeenCalledWith( + [new TextBlock('Line 1'), new TextBlock('[File: file (file://test.txt)]'), new TextBlock('Line 2')], + { invocationState: { a2aRequestContext: context } } + ) + }) + + it('forwards the A2A request context to the agent via invocationState', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model, printer: false }) + const streamSpy = vi.spyOn(agent, 'stream') + const executor = new A2AExecutor(agent) + const eventBus = createMockEventBus() + const context = createRequestContext('hello', 'task-42') + + await executor.execute(context, eventBus) + + expect(streamSpy).toHaveBeenCalledTimes(1) + const [, options] = streamSpy.mock.calls[0]! + expect(options?.invocationState).toEqual({ a2aRequestContext: context }) + }) + + it('re-throws when agent throws, publishing only the initial task event', async () => { + const model = new MockMessageModel().addTurn(new Error('Agent failed')) + const agent = new Agent({ model, printer: false }) + const executor = new A2AExecutor(agent) + const eventBus = createMockEventBus() + + await expect(executor.execute(createRequestContext('Hello'), eventBus)).rejects.toThrow('Agent failed') + + // Only the initial task registration event is published before the error + expect(eventBus.events).toStrictEqual([ + { kind: 'task', id: 'task-1', contextId: 'ctx-1', status: { state: 'working' } }, + ]) + }) + + it('publishes image content blocks as separate file artifacts', async () => { + const imageBytes = new Uint8Array([137, 80, 78, 71]) + const mockAgent: InvokableAgent = { + id: 'test-agent', + name: 'Test Agent', + invoke: vi.fn(), + async *stream() { + const agent = createMockAgent() + // Text delta + yield new ModelStreamUpdateEvent({ + agent, + event: { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'Here is the image:' } }, + invocationState: {}, + }) + // Image content block + yield new ContentBlockEvent({ + agent, + contentBlock: new ImageBlock({ format: 'png', source: { bytes: imageBytes } }), + invocationState: {}, + }) + return new AgentResult({ + stopReason: 'endTurn', + lastMessage: new Message({ role: 'assistant', content: [new TextBlock('Here is the image:')] }), + invocationState: {}, + }) + }, + } + + const executor = new A2AExecutor(mockAgent) + const eventBus = createMockEventBus() + + await executor.execute(createRequestContext('Generate an image'), eventBus) + + const artifactEvents = eventBus.events.filter((e): e is TaskArtifactUpdateEvent => e.kind === 'artifact-update') + + // text delta + image artifact + final text lastChunk = 3 + expect(artifactEvents).toHaveLength(3) + + // First: text delta + expect(artifactEvents[0]!.artifact.parts).toStrictEqual([{ kind: 'text', text: 'Here is the image:' }]) + + // Second: image as file part with its own artifactId + expect(artifactEvents[1]!.artifact.artifactId).not.toBe(artifactEvents[0]!.artifact.artifactId) + expect(artifactEvents[1]!.lastChunk).toBe(true) + expect(artifactEvents[1]!.append).toBe(false) + expect(artifactEvents[1]!.artifact.parts).toStrictEqual([ + { kind: 'file', file: { bytes: encodeBase64(imageBytes), mimeType: 'image/png' } }, + ]) + + // Third: final text lastChunk + expect(artifactEvents[2]!.lastChunk).toBe(true) + }) + + it('throws A2AError.invalidRequest when parts produce no content blocks', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model, printer: false }) + const executor = new A2AExecutor(agent) + const eventBus = createMockEventBus() + + const context: RequestContext = { + taskId: 'task-1', + contextId: 'ctx-1', + userMessage: { kind: 'message', messageId: 'msg-1', role: 'user', parts: [] }, + } + + await expect(executor.execute(context, eventBus)).rejects.toThrow('No content blocks available') + expect(eventBus.events).toStrictEqual([]) + }) + }) + + describe('cancelTask', () => { + it('throws A2AError.unsupportedOperation', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: '' }) + const agent = new Agent({ model, printer: false }) + const executor = new A2AExecutor(agent) + const eventBus = createMockEventBus() + + await expect(executor.cancelTask('task-1', eventBus)).rejects.toThrow('Task cancellation is not supported') + expect(eventBus.events).toStrictEqual([]) + }) + }) +}) diff --git a/strands-ts/src/a2a/__tests__/server.test.node.ts b/strands-ts/src/a2a/__tests__/server.test.node.ts new file mode 100644 index 0000000000..093f69dd0f --- /dev/null +++ b/strands-ts/src/a2a/__tests__/server.test.node.ts @@ -0,0 +1,128 @@ +import { describe, expect, it, vi } from 'vitest' +import { A2AExpressServer, type A2AExpressServerConfig } from '../express-server.js' +import { A2AServer } from '../server.js' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' + +// Mock express +vi.mock('express', () => { + const mockRouter = { + get: vi.fn(), + post: vi.fn(), + use: vi.fn(), + } + const mockApp = { + use: vi.fn(), + listen: vi.fn((_port: number, _host: string, cb: () => void) => { + cb() + return { on: vi.fn(), close: vi.fn(), address: () => ({ port: _port || 54321 }) } + }), + } + const express = Object.assign( + vi.fn(() => mockApp), + { + Router: vi.fn(() => mockRouter), + json: vi.fn(() => 'json-middleware'), + } + ) + return { default: express } +}) + +// Mock A2A SDK express middleware +const mockAgentCardHandler = vi.fn(() => 'agent-card-handler') +const mockJsonRpcHandler = vi.fn(() => 'json-rpc-handler') + +vi.mock('@a2a-js/sdk/server/express', () => ({ + agentCardHandler: (...args: Parameters) => mockAgentCardHandler(...args), + jsonRpcHandler: (...args: Parameters) => mockJsonRpcHandler(...args), + UserBuilder: { noAuthentication: vi.fn() }, +})) + +function createTestConfig(overrides?: Partial): A2AExpressServerConfig { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + return { + agent: new Agent({ model, printer: false }), + name: 'Test Agent', + ...overrides, + } +} + +describe('A2AExpressServer', () => { + describe('constructor', () => { + it('builds agent card with default values', () => { + const server = new A2AExpressServer(createTestConfig()) + + expect(server.agentCard).toStrictEqual({ + name: 'Test Agent', + description: '', + version: '0.0.1', + protocolVersion: '0.2.0', + url: 'http://127.0.0.1:9000', + defaultInputModes: ['text/plain'], + defaultOutputModes: ['text/plain'], + skills: [], + capabilities: { streaming: true }, + }) + }) + + it('uses custom config values', () => { + const server = new A2AExpressServer( + createTestConfig({ + description: 'A helpful agent', + host: '0.0.0.0', + port: 8080, + version: '1.0.0', + skills: [{ id: 'skill-1', name: 'Skill 1', description: 'A skill', tags: [] }], + }) + ) + + expect(server.agentCard).toStrictEqual({ + name: 'Test Agent', + description: 'A helpful agent', + version: '1.0.0', + protocolVersion: '0.2.0', + url: 'http://0.0.0.0:8080', + defaultInputModes: ['text/plain'], + defaultOutputModes: ['text/plain'], + skills: [{ id: 'skill-1', name: 'Skill 1', description: 'A skill', tags: [] }], + capabilities: { streaming: true }, + }) + }) + + it('uses httpUrl override when provided', () => { + const server = new A2AExpressServer(createTestConfig({ httpUrl: 'https://my-agent.example.com' })) + + expect(server.agentCard.url).toBe('https://my-agent.example.com') + }) + + it('accepts custom taskStore', () => { + const taskStore = { save: vi.fn(), load: vi.fn() } + const server = new A2AExpressServer(createTestConfig({ taskStore })) + expect(server.agentCard).toBeDefined() + }) + + it('is an instance of A2AServer', () => { + const server = new A2AExpressServer(createTestConfig()) + expect(server).toBeInstanceOf(A2AServer) + }) + }) + + describe('createMiddleware', () => { + it('returns an express router with SDK middleware', async () => { + const server = new A2AExpressServer(createTestConfig()) + const router = server.createMiddleware() + + expect(router).toBeDefined() + expect(router.use).toHaveBeenCalledTimes(2) + expect(router.use).toHaveBeenCalledWith('/.well-known/agent-card.json', 'agent-card-handler') + expect(router.use).toHaveBeenCalledWith('/', 'json-rpc-handler') + expect(mockAgentCardHandler).toHaveBeenCalledWith({ + agentCardProvider: expect.objectContaining({ getAgentCard: expect.any(Function) }), + }) + expect(mockJsonRpcHandler).toHaveBeenCalledWith({ + requestHandler: expect.anything(), + userBuilder: expect.anything(), + }) + }) + }) +}) diff --git a/strands-ts/src/a2a/__tests__/server.test.ts b/strands-ts/src/a2a/__tests__/server.test.ts new file mode 100644 index 0000000000..b8294d29fb --- /dev/null +++ b/strands-ts/src/a2a/__tests__/server.test.ts @@ -0,0 +1,55 @@ +import { describe, expect, it, vi } from 'vitest' +import { A2AServer } from '../server.js' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' + +describe('A2AServer', () => { + describe('constructor', () => { + it('builds agent card with provided values', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const server = new A2AServer({ + agent: new Agent({ model, printer: false }), + name: 'Base Agent', + description: 'A base agent', + httpUrl: 'http://example.com', + version: '2.0.0', + }) + + expect(server.agentCard).toStrictEqual({ + name: 'Base Agent', + description: 'A base agent', + version: '2.0.0', + protocolVersion: '0.2.0', + url: 'http://example.com', + defaultInputModes: ['text/plain'], + defaultOutputModes: ['text/plain'], + skills: [], + capabilities: { streaming: true }, + }) + }) + + it('uses default values when optional config is omitted', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const server = new A2AServer({ + agent: new Agent({ model, printer: false }), + name: 'Minimal Agent', + }) + + expect(server.agentCard.description).toBe('') + expect(server.agentCard.version).toBe('0.0.1') + expect(server.agentCard.url).toBe('') + expect(server.agentCard.skills).toStrictEqual([]) + }) + + it('accepts custom taskStore', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const taskStore = { save: vi.fn(), load: vi.fn() } + const server = new A2AServer({ + agent: new Agent({ model, printer: false }), + name: 'Agent', + taskStore, + }) + expect(server.agentCard).toBeDefined() + }) + }) +}) diff --git a/strands-ts/src/a2a/a2a-agent.ts b/strands-ts/src/a2a/a2a-agent.ts new file mode 100644 index 0000000000..160ceb38af --- /dev/null +++ b/strands-ts/src/a2a/a2a-agent.ts @@ -0,0 +1,299 @@ +/** + * A2A agent that wraps a remote A2A agent as an InvokableAgent. + * + * Implements the InvokableAgent interface so it can be used anywhere a local Agent + * can be used. The remote agent is invoked via the A2A protocol. + * The A2A protocol is experimental, so breaking changes in the underlying SDK + * may require breaking changes in this module. + */ + +import type { AgentCard, Part } from '@a2a-js/sdk' +import type { Client as A2AClientSdk, ClientFactory as ClientFactoryType } from '@a2a-js/sdk/client' +import { ClientFactory } from '@a2a-js/sdk/client' +import type { InvocationState, InvokableAgent, InvokeArgs, InvokeOptions } from '../types/agent.js' +import { AgentResult } from '../types/agent.js' +import { Message, TextBlock, type ContentBlock, type ContentBlockData, type MessageData } from '../types/messages.js' +import { A2AStreamUpdateEvent, A2AResultEvent, type A2AEventData, type A2AStreamEvent } from './events.js' +import { logger } from '../logging/logger.js' +import { logExperimentalWarning } from './logging.js' + +/** + * Configuration options for creating an A2AAgent. + */ +export interface A2AAgentConfig { + /** Base URL of the remote A2A agent */ + url: string + /** Path to the agent card endpoint (default: '/.well-known/agent-card.json') */ + agentCardPath?: string + /** Optional unique identifier. Defaults to the URL. */ + id?: string + /** Optional name. If not provided, populated from the agent card after connection. */ + name?: string + /** Optional description. If not provided, populated from the agent card after connection. */ + description?: string + /** Optional custom A2A ClientFactory for authenticating requests (e.g. SigV4, bearer token). */ + clientFactory?: ClientFactoryType +} + +/** + * Wraps a remote A2A agent as an InvokableAgent. + * + * Implements `InvokableAgent` so it can be used polymorphically with local `Agent` instances. + * On invocation, the agent lazily connects to the remote endpoint via the A2A protocol + * and returns the response as an `AgentResult`. + * + * @example + * ```typescript + * import { A2AAgent } from '@strands-agents/sdk/a2a' + * + * const remoteAgent = new A2AAgent({ url: 'http://localhost:9000' }) + * const result = await remoteAgent.invoke('Hello, remote agent!') + * console.log(result.toString()) + * ``` + */ +export class A2AAgent implements InvokableAgent { + private _config: A2AAgentConfig + private _client: A2AClientSdk | undefined + private _agentCard: AgentCard | undefined + + /** + * The unique identifier of the agent instance. + */ + readonly id: string + + /** + * The name of the agent. + * If not provided in config, populated from the agent card after connection. + */ + readonly name?: string + + /** + * Optional description of what the agent does. + * If not provided in config, populated from the agent card after connection. + */ + readonly description?: string + + /** + * Creates a new A2AAgent. + * + * @param config - Configuration for connecting to the remote agent + */ + constructor(config: A2AAgentConfig) { + this._config = config + this.id = config.id ?? config.url + if (config.name !== undefined) this.name = config.name + if (config.description !== undefined) this.description = config.description + } + + /** + * Invokes the remote agent and returns the final result. + * + * Built on top of `stream()` — consumes the full event stream and returns the final result. + * + * @param args - Arguments for invoking the agent + * @param options - Optional invocation options. See {@link stream} for behavior. + * @returns Promise that resolves to the AgentResult + */ + async invoke(args: InvokeArgs, options?: InvokeOptions): Promise { + const gen = this.stream(args, options) + let next = await gen.next() + while (!next.done) { + next = await gen.next() + } + return next.value + } + + /** + * Streams the remote agent execution, yielding A2A events as they arrive. + * + * Yields `A2AStreamUpdateEvent` for each raw A2A protocol event (Message, Task, + * TaskStatusUpdateEvent, TaskArtifactUpdateEvent), followed by an `A2AResultEvent` + * containing the final result built from the last complete event. + * + * @param args - Arguments for invoking the agent + * @param options - Optional invocation options. If `invocationState` is + * provided, it is returned on the resulting `AgentResult`. The remote + * agent runs in another process and cannot read or mutate it. Other + * fields on `options` are ignored. + * @returns Async generator that yields AgentStreamEvent objects and returns AgentResult + */ + async *stream(args: InvokeArgs, options?: InvokeOptions): AsyncGenerator { + const client = await this._getClient() + const text = this._extractTextFromArgs(args) + const invocationState = options?.invocationState ?? {} + + let lastEvent: A2AEventData | undefined + let lastCompleteEvent: A2AEventData | undefined + const artifactTexts = new Map() + + const eventStream = client.sendMessageStream({ + message: { + kind: 'message', + messageId: globalThis.crypto.randomUUID(), + role: 'user', + parts: [{ kind: 'text', text }], + }, + }) + + for await (const event of eventStream) { + lastEvent = event + if (this._isCompleteEvent(event)) { + lastCompleteEvent = event + } + if (event.kind === 'artifact-update') { + const id = event.artifact.artifactId + if (!event.append) { + artifactTexts.set(id, []) + } + const chunks = artifactTexts.get(id) ?? [] + const chunkText = this._textFromParts(event.artifact.parts) + if (chunkText) { + chunks.push(chunkText) + artifactTexts.set(id, chunks) + } + } + yield new A2AStreamUpdateEvent(event) + } + + const finalEvent = lastCompleteEvent ?? lastEvent + const accumulatedText = [...artifactTexts.values()].map((chunks) => chunks.join('')).join('\n') + const result = this._buildResult(finalEvent, invocationState, accumulatedText) + + yield new A2AResultEvent({ result }) + return result + } + + /** + * Returns the cached A2A SDK client, creating one lazily on first use. + * Also fetches and caches the agent card for name/description. + * + * @returns The A2A SDK client + */ + private async _getClient(): Promise { + if (this._client) { + return this._client + } + + logExperimentalWarning() + + const factory = this._config.clientFactory ?? new ClientFactory() + const client = await factory.createFromUrl(this._config.url, this._config.agentCardPath) + this._agentCard = await client.getAgentCard() + if (this.name === undefined && this._agentCard?.name) { + ;(this as { name?: string }).name = this._agentCard.name + } + if (this.description === undefined && this._agentCard?.description) { + ;(this as { description?: string }).description = this._agentCard.description + } + this._client = client + return client + } + + /** + * Extracts a text string from InvokeArgs for sending to the remote agent. + * + * @param args - The invocation arguments + * @returns The extracted text string + */ + private _extractTextFromArgs(args: InvokeArgs): string { + if (typeof args === 'string') return args + if (!Array.isArray(args) || args.length === 0) return '' + + // Message[] or MessageData[] — find last user message's content + if ('role' in args[0]!) { + const messages = args as (Message | MessageData)[] + const lastUser = messages + .slice() + .reverse() + .find((m) => m.role === 'user') + if (!lastUser) return '' + args = lastUser instanceof Message ? lastUser.content : (lastUser.content as ContentBlockData[]) + } + + // ContentBlock[] or ContentBlockData[] — join text from all text blocks + const blocks = args as (ContentBlock | ContentBlockData)[] + const nonTextCount = blocks.filter((b) => ('type' in b ? b.type !== 'textBlock' : !('text' in b))).length + if (nonTextCount > 0) { + logger.warn( + `non_text_blocks=<${nonTextCount}> | stripping non-text content blocks, A2AAgent does not yet support non-text content` + ) + } + + return blocks + .filter((b): b is TextBlock => ('type' in b ? b.type === 'textBlock' : 'text' in b)) + .map((b) => b.text) + .join('\n') + } + + /** + * Checks whether an A2A streaming event represents a complete response. + * + * @param event - The A2A streaming event + * @returns True if the event is a terminal/complete event + */ + private _isCompleteEvent(event: A2AEventData): boolean { + if (event.kind === 'message') return true + if (event.kind === 'task') return true + if (event.kind === 'artifact-update') return event.lastChunk === true + if (event.kind === 'status-update') return event.status.state === 'completed' + return false + } + + /** + * Builds an AgentResult from the final A2A streaming event. + * + * @param event - The final A2A event, or undefined if no events were received + * @param invocationState - Caller-provided invocation state, threaded through to the result + * @param accumulatedText - Optional accumulated text from streaming artifacts + * @returns The constructed AgentResult + */ + private _buildResult( + event: A2AEventData | undefined, + invocationState: InvocationState, + accumulatedText?: string + ): AgentResult { + const text = this._extractTextFromEvent(event) || accumulatedText || '' + const lastMessage = new Message({ + role: 'assistant', + content: [new TextBlock(text)], + }) + return new AgentResult({ stopReason: 'endTurn', lastMessage, invocationState }) + } + + /** + * Extracts text content from an A2A streaming event. + * + * @param event - The A2A streaming event + * @returns Extracted text content + */ + private _extractTextFromEvent(event: A2AEventData | undefined): string { + if (!event) return '' + if (event.kind === 'message') { + return this._textFromParts(event.parts) + } + if (event.kind === 'task') { + const parts = event.artifacts?.flatMap((a) => a.parts) ?? [] + return this._textFromParts(parts) || this._textFromParts(event.status?.message?.parts ?? []) + } + if (event.kind === 'artifact-update') { + return this._textFromParts(event.artifact.parts) + } + if (event.kind === 'status-update' && event.status.message) { + return this._textFromParts(event.status.message.parts) + } + return '' + } + + /** + * Joins text from A2A parts, filtering out non-text parts. + * + * @param parts - Array of A2A parts + * @returns Joined text content + */ + private _textFromParts(parts: Part[]): string { + return parts + .filter((p): p is Part & { kind: 'text'; text: string } => p.kind === 'text') + .map((p) => p.text) + .join('\n') + } +} diff --git a/strands-ts/src/a2a/adapters.ts b/strands-ts/src/a2a/adapters.ts new file mode 100644 index 0000000000..2eaa27ab8b --- /dev/null +++ b/strands-ts/src/a2a/adapters.ts @@ -0,0 +1,198 @@ +/** + * Conversion utilities between Strands SDK content blocks and A2A protocol parts. + * + * Supports text, images, videos, documents, and structured data. + */ + +import type { Part, FileWithBytes, FileWithUri } from '@a2a-js/sdk' +import type { ContentBlock } from '../types/messages.js' +import { TextBlock } from '../types/messages.js' +import type { ImageFormat, DocumentFormat, VideoFormat } from '../mime.js' +import { toMimeType, toMediaFormat } from '../mime.js' +import { ImageBlock, VideoBlock, DocumentBlock, decodeBase64, encodeBase64 } from '../types/media.js' +import { logger } from '../logging/logger.js' + +/** + * Converts A2A protocol parts to Strands SDK content blocks. + * + * Handles text, file (image/video/document), and structured data parts, + * @param parts - Array of A2A protocol parts + * @returns Array of Strands content blocks + */ +export function partsToContentBlocks(parts: Part[]): ContentBlock[] { + const blocks: ContentBlock[] = [] + + for (const part of parts) { + try { + switch (part.kind) { + case 'text': + blocks.push(new TextBlock(part.text)) + break + case 'file': + blocks.push(_convertFilePart(part.file)) + break + case 'data': + blocks.push(new TextBlock(`[Structured Data]\n${JSON.stringify(part.data, null, 2)}`)) + break + } + } catch { + logger.warn(`part_kind=<${part.kind}> | failed to convert A2A part to content block`) + } + } + + return blocks +} + +/** + * Converts Strands SDK content blocks to A2A protocol parts. + * + * Supports text, image, video, and document blocks. Image and video blocks + * with byte sources are encoded as base64 file parts; URL-based sources + * become URI file parts. Unsupported block types are silently skipped. + * + * @param blocks - Array of Strands content blocks + * @returns Array of A2A parts + */ +export function contentBlocksToParts(blocks: ContentBlock[]): Part[] { + const parts: Part[] = [] + + for (const block of blocks) { + switch (block.type) { + case 'textBlock': + parts.push({ kind: 'text', text: block.text }) + break + case 'imageBlock': + case 'videoBlock': { + const filePart = _mediaBlockToFilePart(block) + if (filePart) parts.push(filePart) + break + } + case 'documentBlock': { + const filePart = _documentBlockToFilePart(block) + if (filePart) parts.push(filePart) + break + } + } + } + + return parts +} + +/** + * Converts an A2A FilePart to the appropriate Strands content block. + * + * @param file - The file object from a FilePart (either bytes or URI based) + * @returns ContentBlock for the file + */ +function _convertFilePart(file: FileWithBytes | FileWithUri): ContentBlock { + if ('bytes' in file) { + const decoded = decodeBase64(file.bytes) + const fileType = _getFileType(file.mimeType) + const format = _getFormat(file.mimeType, fileType) + + if (fileType === 'image') { + return new ImageBlock({ format: format as ImageFormat, source: { bytes: decoded } }) + } + + if (fileType === 'video') { + return new VideoBlock({ format: format as VideoFormat, source: { bytes: decoded } }) + } + + // Document or unknown — treat as document + return new DocumentBlock({ + name: file.name ?? 'document', + format: format as DocumentFormat, + source: { bytes: decoded }, + }) + } + + const name = file.name ?? 'file' + return new TextBlock(`[File: ${name} (${file.uri})]`) +} + +/** + * Classifies a MIME type into a file category. + * + * @param mimeType - The MIME type string + * @returns The file type category + */ +function _getFileType(mimeType: string | undefined): 'image' | 'video' | 'document' | 'unknown' { + if (!mimeType) { + return 'unknown' + } + + const lower = mimeType.toLowerCase() + if (lower.startsWith('image/')) return 'image' + if (lower.startsWith('video/')) return 'video' + if (lower.startsWith('text/') || lower.startsWith('application/')) return 'document' + return 'unknown' +} + +/** + * Resolves a MIME type to a Strands media format using the reverse MIME_TYPES lookup. + * Falls back to the MIME subtype for unrecognized types. + * + * @param mimeType - The MIME type string + * @param fileType - The classified file type + * @returns The format string + */ +function _getFormat(mimeType: string | undefined, fileType: string): string { + if (!mimeType) { + return fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : 'txt' + } + + const lower = mimeType.toLowerCase() + + // Use the reverse lookup (handles complex types like application/vnd.ms-excel → xls) + const known = toMediaFormat(lower) + if (known) { + return known + } + + // Fallback: extract subtype from MIME (e.g., image/tiff → tiff) + if (lower.includes('/')) { + return lower.split('/').pop()! + } + + return 'txt' +} + +/** + * Converts an ImageBlock or VideoBlock to an A2A FilePart. + * + * @param block - The image or video block + * @returns A2A FilePart, or undefined if the source type is unsupported + */ +function _mediaBlockToFilePart(block: ImageBlock | VideoBlock): Part | undefined { + const mimeType = toMimeType(block.format)! + + if (block.source.type === 'imageSourceBytes' || block.source.type === 'videoSourceBytes') { + return { kind: 'file', file: { bytes: encodeBase64(block.source.bytes), mimeType } } + } + + if (block.source.type === 'imageSourceUrl') { + return { kind: 'file', file: { uri: block.source.url, mimeType } } + } + + return undefined +} + +/** + * Converts a DocumentBlock to an A2A FilePart. + * + * @param block - The document block + * @returns A2A FilePart, or undefined if the source type is unsupported + */ +function _documentBlockToFilePart(block: DocumentBlock): Part | undefined { + const mimeType = toMimeType(block.format)! + + if (block.source.type === 'documentSourceBytes') { + return { kind: 'file', file: { bytes: encodeBase64(block.source.bytes), mimeType, name: block.name } } + } + + if (block.source.type === 'documentSourceText') { + return { kind: 'text', text: block.source.text } + } + + return undefined +} diff --git a/strands-ts/src/a2a/events.ts b/strands-ts/src/a2a/events.ts new file mode 100644 index 0000000000..fd27c1fc78 --- /dev/null +++ b/strands-ts/src/a2a/events.ts @@ -0,0 +1,62 @@ +/** + * A2A-specific stream events yielded by A2AAgent.stream(). + */ + +import type { Message, Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent } from '@a2a-js/sdk' +import { StreamEvent } from '../hooks/events.js' +import type { AgentResultEvent } from '../hooks/events.js' + +/** + * Union of raw A2A protocol event types received during streaming. + */ +export type A2AEventData = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + +/** + * Event wrapping a raw A2A protocol streaming event. + * + * Yielded by `A2AAgent.stream()` for each event received from the remote agent. + * The `event` property contains the raw A2A SDK event data, discriminated by `kind`: + * - `'message'` — A2A Message + * - `'task'` — A2A Task + * - `'status-update'` — TaskStatusUpdateEvent + * - `'artifact-update'` — TaskArtifactUpdateEvent + */ +export class A2AStreamUpdateEvent extends StreamEvent { + readonly type = 'a2aStreamUpdateEvent' as const + readonly event: A2AEventData + + constructor(event: A2AEventData) { + super() + this.event = event + } + + toJSON(): Pick { + return { type: this.type, event: this.event } + } +} + +/** + * Event triggered as the final event in the A2A agent stream. + * Wraps the agent result containing the stop reason and last message. + */ +export class A2AResultEvent extends StreamEvent { + readonly type = 'a2aResultEvent' as const + readonly result: AgentResultEvent['result'] + + constructor(data: Pick) { + super() + this.result = data.result + } + + toJSON(): Pick { + return { type: this.type, result: this.result } + } +} + +/** + * Union of all events yielded by `A2AAgent.stream()`. + * + * Includes raw A2A protocol events ({@link A2AStreamUpdateEvent}) and the final + * result event ({@link A2AResultEvent}). + */ +export type A2AStreamEvent = A2AStreamUpdateEvent | A2AResultEvent diff --git a/strands-ts/src/a2a/executor.ts b/strands-ts/src/a2a/executor.ts new file mode 100644 index 0000000000..0ba75f3887 --- /dev/null +++ b/strands-ts/src/a2a/executor.ts @@ -0,0 +1,165 @@ +/** + * A2A executor that bridges a Strands Agent into the A2A protocol. + * + * Implements the AgentExecutor interface from `@a2a-js/sdk/server` to allow + * a Strands Agent to handle A2A JSON-RPC requests. + */ + +import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server' +import type { AgentExecutor } from '@a2a-js/sdk/server' +import { A2AError } from '@a2a-js/sdk/server' +import type { InvokableAgent } from '../types/agent.js' +import { ModelStreamUpdateEvent, ContentBlockEvent } from '../hooks/events.js' +import { contentBlocksToParts, partsToContentBlocks } from './adapters.js' +import { normalizeError } from '../errors.js' +import { logger } from '../logging/logger.js' + +/** + * Bridges a Strands Agent into the A2A protocol as an AgentExecutor. + * + * Converts A2A message parts to Strands content blocks, streams the agent + * execution, and publishes text deltas as artifact updates through the A2A + * event bus. Text chunks are appended to a single artifact as they arrive, + * implementing A2A-compliant streaming behavior. + * + * ## Invocation state + * + * The executor populates the agent's `invocationState` with the incoming A2A + * {@link RequestContext} under the reserved key `a2aRequestContext`. Hooks and + * tools running inside the agent can read `event.invocationState.a2aRequestContext` + * to correlate with the A2A request (taskId, contextId, user message metadata) + * for logging, metrics, or audit. + * + * Because the A2A framework (not user code) drives `execute()`, there is no + * per-request path for the user to supply their own `invocationState`. If a + * user hook writes to the `a2aRequestContext` key, it will be overwritten on + * the next request. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { A2AExecutor } from '@strands-agents/sdk/a2a' + * + * const agent = new Agent({ model: 'my-model' }) + * const executor = new A2AExecutor(agent) + * ``` + */ +export class A2AExecutor implements AgentExecutor { + private _agent: InvokableAgent + + /** + * Creates a new A2AExecutor. + * + * @param agent - The agent to execute for incoming A2A requests + */ + constructor(agent: InvokableAgent) { + this._agent = agent + } + + /** + * Executes the agent in response to an A2A message. + * + * Converts A2A message parts to Strands content blocks, then streams the + * agent execution. Text deltas are streamed incrementally into a single + * artifact; non-text content blocks (images, videos, documents) are each + * published as separate complete artifacts. A final artifact with + * `lastChunk: true` signals the end of the text artifact, followed by a + * completed status update. + * + * @param context - The A2A request context containing the user message + * @param eventBus - The event bus for publishing A2A artifact and status events + */ + async execute(context: RequestContext, eventBus: ExecutionEventBus): Promise { + const { taskId, contextId, userMessage } = context + const contentBlocks = partsToContentBlocks(userMessage.parts) + if (contentBlocks.length === 0) { + throw A2AError.invalidRequest('No content blocks available') + } + + // Publish initial task event to register the task with the ResultManager. + // Without this, artifact and status events are ignored as "unknown task". + eventBus.publish({ kind: 'task', id: taskId, contextId, status: { state: 'working' } }) + + const artifactId = globalThis.crypto.randomUUID() + let isFirstChunk = true + + try { + // Forward the A2A RequestContext to the agent under a reserved key so + // hooks and tools can correlate with the A2A request (taskId, contextId, + // user message metadata). + const stream = this._agent.stream(contentBlocks, { + invocationState: { a2aRequestContext: context }, + }) + let next = await stream.next() + + while (!next.done) { + const event = next.value + + // Stream text deltas incrementally into the text artifact + if ( + event instanceof ModelStreamUpdateEvent && + event.event.type === 'modelContentBlockDeltaEvent' && + event.event.delta.type === 'textDelta' + ) { + eventBus.publish({ + kind: 'artifact-update', + taskId, + contextId, + artifact: { + artifactId, + parts: [{ kind: 'text', text: event.event.delta.text }], + }, + append: !isFirstChunk, + }) + isFirstChunk = false + } + + // Publish non-text content blocks (images, videos, documents) as separate artifacts + if (event instanceof ContentBlockEvent && event.contentBlock.type !== 'textBlock') { + const parts = contentBlocksToParts([event.contentBlock]) + if (parts.length > 0) { + eventBus.publish({ + kind: 'artifact-update', + taskId, + contextId, + artifact: { artifactId: globalThis.crypto.randomUUID(), parts }, + append: false, + lastChunk: true, + }) + } + } + + next = await stream.next() + } + + // Publish final artifact chunk to signal end of artifact + eventBus.publish({ + kind: 'artifact-update', + taskId, + contextId, + artifact: { + artifactId, + // If no deltas were streamed, publish the full result; otherwise empty to close the artifact + parts: [{ kind: 'text', text: isFirstChunk && next.value ? next.value.toString() : '' }], + }, + append: !isFirstChunk, // false for new artifact, true to append to streamed chunks + lastChunk: true, // Always true — this runs after the stream loop ends + }) + + eventBus.publish({ kind: 'status-update', taskId, contextId, status: { state: 'completed' }, final: true }) + } catch (error) { + logger.error(`task_id=<${taskId}> | error in streaming execution`, normalizeError(error)) + throw error + } + } + + /** + * Cancels a running task. Not supported by this executor. + * + * @param taskId - The ID of the task to cancel + * @param eventBus - The event bus for publishing status events + */ + async cancelTask(_taskId: string, _eventBus: ExecutionEventBus): Promise { + throw A2AError.unsupportedOperation('Task cancellation is not supported') + } +} diff --git a/strands-ts/src/a2a/express-server.ts b/strands-ts/src/a2a/express-server.ts new file mode 100644 index 0000000000..20a1782733 --- /dev/null +++ b/strands-ts/src/a2a/express-server.ts @@ -0,0 +1,128 @@ +/** + * Express-based A2A server that exposes a Strands Agent as an A2A-compliant HTTP endpoint. + * + * Separated from the base {@link A2AServer} so that importing the core A2A module + * does not pull in Express as a dependency, keeping it browser-compatible. + * + * The A2A protocol is experimental, so breaking changes in the underlying SDK + * may require breaking changes in this module. + */ + +import express, { type Router } from 'express' +import { agentCardHandler, jsonRpcHandler, UserBuilder } from '@a2a-js/sdk/server/express' +import { A2AServer, type A2AServerConfig } from './server.js' +import { logExperimentalWarning } from './logging.js' +import { logger } from '../logging/logger.js' + +/** + * Configuration options for creating an A2AExpressServer. + */ +export interface A2AExpressServerConfig extends A2AServerConfig { + /** Host to bind the server to (default: '127.0.0.1') */ + host?: string + /** Port to listen on (default: 9000) */ + port?: number + /** User builder for authentication (default: no authentication) */ + userBuilder?: UserBuilder +} + +/** + * Express-based A2A server implementation. + * + * Provides two usage modes: + * - **Standalone**: Call {@link serve} to start a self-contained HTTP server. + * - **Middleware**: Call {@link createMiddleware} to get an Express Router that + * can be mounted in an existing Express application. + */ +export class A2AExpressServer extends A2AServer { + private _host: string + private _port: number + private _userBuilder: UserBuilder | undefined + + /** + * Creates a new A2AExpressServer. + * + * @param config - Configuration for the server + */ + constructor(config: A2AExpressServerConfig) { + const host = config.host ?? '127.0.0.1' + const port = config.port ?? 9000 + const httpUrl = config.httpUrl ?? `http://${host}:${port}` + + super({ ...config, httpUrl }) + + this._host = host + this._port = port + this._userBuilder = config.userBuilder + } + + /** + * Returns the port the server is configured to listen on. + * After `serve()` resolves, this reflects the actual bound port + * (useful when configured with port 0 for OS-assigned ports). + */ + get port(): number { + return this._port + } + + /** + * Creates an Express Router middleware for the A2A endpoints. + * + * Mounts: + * - `GET /.well-known/agent-card.json` — Returns the agent card + * - `POST /` — Handles A2A JSON-RPC requests + * + * @returns An Express Router with A2A endpoints mounted + */ + createMiddleware(): Router { + logExperimentalWarning() + + const router = express.Router() + + router.use('/.well-known/agent-card.json', agentCardHandler({ agentCardProvider: this._requestHandler })) + + router.use( + '/', + jsonRpcHandler({ + requestHandler: this._requestHandler, + userBuilder: this._userBuilder ?? UserBuilder.noAuthentication, + }) + ) + + return router + } + + /** + * Starts the HTTP server and begins listening for A2A requests. + * + * @param options - Optional server options + */ + async serve(options?: { signal?: AbortSignal }): Promise { + const app = express() + app.use(this.createMiddleware()) + + return new Promise((resolve, reject) => { + const server = app.listen(this._port, this._host, () => { + const addr = server.address() + if (addr && typeof addr === 'object') { + this._port = addr.port + this._agentCard.url = `http://${this._host}:${this._port}` + } + logger.info(`a2a server listening on http://${this._host}:${this._port}`) + resolve() + }) + + server.on('error', reject) + + if (options?.signal) { + options.signal.addEventListener( + 'abort', + () => { + server.close() + }, + { once: true } + ) + } + }) + } +} diff --git a/strands-ts/src/a2a/index.ts b/strands-ts/src/a2a/index.ts new file mode 100644 index 0000000000..40818e7c7e --- /dev/null +++ b/strands-ts/src/a2a/index.ts @@ -0,0 +1,15 @@ +/** + * A2A (Agent-to-Agent) protocol support for the Strands Agents SDK. + * + * This module provides server and client components for the A2A protocol, + * allowing Strands agents to communicate with other agents across platforms. + * + * @remarks + * The A2A protocol is experimental, so breaking changes in the underlying SDK + * may require breaking changes in this module. + */ + +export { A2AServer, type A2AServerConfig } from './server.js' +export { A2AAgent, type A2AAgentConfig } from './a2a-agent.js' +export { A2AStreamUpdateEvent, A2AResultEvent, type A2AEventData, type A2AStreamEvent } from './events.js' +export { A2AExecutor } from './executor.js' diff --git a/strands-ts/src/a2a/logging.ts b/strands-ts/src/a2a/logging.ts new file mode 100644 index 0000000000..5e1b94c38f --- /dev/null +++ b/strands-ts/src/a2a/logging.ts @@ -0,0 +1,19 @@ +/** + * Shared experimental warning for A2A protocol modules. + */ + +import { logger } from '../logging/logger.js' + +let _logged = false + +/** + * Logs a one-time warning that the A2A protocol is experimental. + */ +export function logExperimentalWarning(): void { + if (!_logged) { + _logged = true + logger.warn( + 'protocol= | experimental, breaking changes in the underlying sdk may require breaking changes in this module' + ) + } +} diff --git a/strands-ts/src/a2a/server.ts b/strands-ts/src/a2a/server.ts new file mode 100644 index 0000000000..eddd02263b --- /dev/null +++ b/strands-ts/src/a2a/server.ts @@ -0,0 +1,95 @@ +/** + * Base A2A server that manages agent card and request handler setup. + * + * This module is browser-compatible. For Express-based HTTP serving, + * see {@link A2AExpressServer} in `./express-server.ts`. + * + * The A2A protocol is experimental, so breaking changes in the underlying SDK + * may require breaking changes in this module. + */ + +import type { AgentCard, AgentSkill } from '@a2a-js/sdk' +import type { TaskStore, A2ARequestHandler } from '@a2a-js/sdk/server' +import { DefaultRequestHandler, InMemoryTaskStore } from '@a2a-js/sdk/server' +import type { InvokableAgent } from '../types/agent.js' +import { A2AExecutor } from './executor.js' + +/** + * Configuration options for creating an A2AServer. + */ +export interface A2AServerConfig { + /** The Strands Agent to serve via A2A protocol */ + agent: InvokableAgent + /** Human-readable name for the agent */ + name: string + /** Optional description of the agent's purpose */ + description?: string + /** Public URL override for the agent card */ + httpUrl?: string + /** Version string for the agent card (default: '0.0.1') */ + version?: string + /** Skills to advertise in the agent card */ + skills?: AgentSkill[] + /** Task store for persisting task state */ + taskStore?: TaskStore +} + +/** + * Base A2A server that manages agent card and request handler setup. + * + * Subclass this to integrate with different HTTP frameworks. For Express, + * use {@link A2AExpressServer}. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { A2AExpressServer } from '@strands-agents/sdk/a2a/express' + * + * const agent = new Agent({ model: 'my-model' }) + * const server = new A2AExpressServer({ + * agent, + * name: 'My Agent', + * description: 'An agent that helps with tasks', + * }) + * + * await server.serve() + * ``` + */ +export class A2AServer { + protected _agentCard: AgentCard + protected _requestHandler: A2ARequestHandler + + /** + * Creates a new A2AServer. + * + * @param config - Configuration for the server + */ + constructor(config: A2AServerConfig) { + const httpUrl = config.httpUrl ?? '' + + this._agentCard = { + name: config.name, + description: config.description ?? '', + version: config.version ?? '0.0.1', + protocolVersion: '0.2.0', + url: httpUrl, + defaultInputModes: ['text/plain'], + defaultOutputModes: ['text/plain'], + skills: config.skills ?? [], + capabilities: { + streaming: true, + }, + } + + const taskStore = config.taskStore ?? new InMemoryTaskStore() + const executor = new A2AExecutor(config.agent) + this._requestHandler = new DefaultRequestHandler(this._agentCard, taskStore, executor) + } + + /** + * Returns the agent card for this server. + */ + get agentCard(): AgentCard { + return this._agentCard + } +} diff --git a/strands-ts/src/agent/__tests__/agent-as-tool.invocation-state.test.ts b/strands-ts/src/agent/__tests__/agent-as-tool.invocation-state.test.ts new file mode 100644 index 0000000000..e4ee983d8c --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent-as-tool.invocation-state.test.ts @@ -0,0 +1,28 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../agent.js' +import { BeforeModelCallEvent } from '../../hooks/events.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import type { InvocationState } from '../../types/agent.js' + +describe('AgentAsTool invocationState forwarding', () => { + it('forwards outer invocationState into the wrapped agent and reflects inner mutations on outer result', async () => { + const innerModel = new MockMessageModel().addTurn({ type: 'textBlock', text: 'inner-done' }) + const inner = new Agent({ model: innerModel, name: 'inner', description: 'inner agent' }) + + let innerSawState: InvocationState | undefined + inner.addHook(BeforeModelCallEvent, (event) => { + innerSawState = event.invocationState + event.invocationState.innerTouched = true + }) + + const outerModel = new MockMessageModel() + .addTurn([{ type: 'toolUseBlock', name: 'inner', toolUseId: 'tu-1', input: { input: 'hi' } }]) + .addTurn({ type: 'textBlock', text: 'outer-done' }) + const outer = new Agent({ model: outerModel, tools: [inner.asTool()] }) + + const result = await outer.invoke('run inner', { invocationState: { userId: 'u-1' } }) + + expect(innerSawState).toEqual({ userId: 'u-1', innerTouched: true }) + expect(result.invocationState).toEqual({ userId: 'u-1', innerTouched: true }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent-as-tool.test.ts b/strands-ts/src/agent/__tests__/agent-as-tool.test.ts new file mode 100644 index 0000000000..4be5cb4b95 --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent-as-tool.test.ts @@ -0,0 +1,431 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../agent.js' +import { AgentAsTool } from '../agent-as-tool.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import { createMockContext } from '../../__fixtures__/tool-helpers.js' +import { ToolValidationError } from '../../errors.js' +import { Tool, ToolStreamEvent } from '../../tools/tool.js' +import { ToolResultBlock } from '../../types/messages.js' +import { SessionManager } from '../../session/session-manager.js' +import type { SnapshotStorage } from '../../session/storage.js' + +describe('AgentAsTool', () => { + describe('properties', () => { + it('uses agent name as default tool name', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'researcher' }) + const tool = new AgentAsTool({ agent }) + + expect(tool.name).toBe('researcher') + }) + + it('allows overriding the tool name', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'researcher' }) + const tool = new AgentAsTool({ agent, name: 'research-tool' }) + + expect(tool.name).toBe('research-tool') + }) + + it('uses agent description as default tool description', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'researcher', description: 'Finds information' }) + const tool = new AgentAsTool({ agent }) + + expect(tool.description).toBe('Finds information') + }) + + it('falls back to generic description when agent has no description', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'researcher' }) + const tool = new AgentAsTool({ agent }) + + expect(tool.description).toBe('Use the researcher agent by providing a natural language input') + }) + + it('allows overriding the tool description', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'researcher' }) + const tool = new AgentAsTool({ agent, description: 'Custom description' }) + + expect(tool.description).toBe('Custom description') + }) + + it('exposes the wrapped agent via getter', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'researcher' }) + const tool = new AgentAsTool({ agent }) + + expect(tool.agent).toBe(agent) + }) + + it('has correct toolSpec shape', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'researcher', description: 'Finds info' }) + const tool = new AgentAsTool({ agent }) + + expect(tool.toolSpec).toEqual({ + name: 'researcher', + description: 'Finds info', + inputSchema: { + type: 'object', + properties: { + input: { + type: 'string', + description: 'The natural language input to send to the agent.', + }, + }, + required: ['input'], + }, + }) + }) + }) + + describe('name validation', () => { + it('throws when registered with agent name containing spaces', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const subAgent = new Agent({ model, name: 'Strands Agent' }) + + expect(() => new Agent({ model, tools: [subAgent] })).toThrow(ToolValidationError) + }) + + it('throws when registered with explicit name containing invalid characters', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const subAgent = new Agent({ model, name: 'researcher' }) + + expect(() => new Agent({ model, tools: [subAgent.asTool({ name: 'has spaces' })] })).toThrow(ToolValidationError) + }) + + it('accepts valid name with hyphens and underscores', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'my_research-agent' }) + const tool = new AgentAsTool({ agent }) + + expect(tool.name).toBe('my_research-agent') + }) + }) + + describe('stream', () => { + it('invokes the wrapped agent and returns text result', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Agent response' }) + const agent = new Agent({ model, name: 'test-agent', printer: false }) + const tool = new AgentAsTool({ agent }) + + const context = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-1', + input: { input: 'Hello agent' }, + }) + + const { result } = await collectGenerator(tool.stream(context)) + + expect(result.toolUseId).toBe('tool-1') + expect(result.status).toBe('success') + expect(result.content).toHaveLength(1) + expect(result.content[0]).toEqual( + expect.objectContaining({ + type: 'textBlock', + text: 'Agent response', + }) + ) + }) + + it('yields ToolStreamEvents wrapping sub-agent events', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, name: 'test-agent', printer: false }) + const tool = new AgentAsTool({ agent }) + + const context = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-1', + input: { input: 'Hi' }, + }) + + const { items } = await collectGenerator(tool.stream(context)) + + expect(items.length).toBeGreaterThan(0) + for (const item of items) { + expect(item).toBeInstanceOf(ToolStreamEvent) + } + }) + + it('unwraps toolStreamUpdateEvent by yielding inner ToolStreamEvent directly', async () => { + // Create a tool that yields ToolStreamEvents during execution. + // When the sub-agent runs this tool, the agent loop wraps each yielded + // ToolStreamEvent in a ToolStreamUpdateEvent. The AgentAsTool should + // unwrap these back to bare ToolStreamEvents instead of double-wrapping. + const streamingTool = { + name: 'streaming-tool', + description: 'A tool that yields stream events', + toolSpec: { + name: 'streaming-tool', + description: 'A tool that yields stream events', + inputSchema: { type: 'object' as const, properties: {} }, + }, + async *stream(context: any) { + yield new ToolStreamEvent({ data: 'progress-1' }) + yield new ToolStreamEvent({ data: 'progress-2' }) + return new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success' as const, + content: [], + }) + }, + } as Tool + + // Turn 1: model requests tool use → triggers the streaming tool + // Turn 2: model responds with text after tool result + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'streaming-tool', + toolUseId: 'sub-tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Final response' }) + + const agent = new Agent({ model, name: 'test-agent', tools: [streamingTool], printer: false }) + const tool = new AgentAsTool({ agent }) + + const context = createMockContext({ + name: 'test-agent', + toolUseId: 'outer-tool-1', + input: { input: 'Do something' }, + }) + + const { items } = await collectGenerator(tool.stream(context)) + + // All yielded items should be ToolStreamEvent instances + for (const item of items) { + expect(item).toBeInstanceOf(ToolStreamEvent) + } + + // Find the unwrapped events from the streaming tool. + // If unwrapping works correctly, data is the original string. + // If double-wrapped, data would be a ToolStreamUpdateEvent object. + const progressEvents = items.filter((item) => item.data === 'progress-1' || item.data === 'progress-2') + + expect(progressEvents).toHaveLength(2) + expect(progressEvents[0]!.data).toBe('progress-1') + expect(progressEvents[1]!.data).toBe('progress-2') + }) + + it('returns error result on agent failure', async () => { + const model = new MockMessageModel().addTurn(new Error('Model failed')) + const agent = new Agent({ model, name: 'test-agent', printer: false }) + const tool = new AgentAsTool({ agent }) + + const context = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-1', + input: { input: 'Hello' }, + }) + + const { result } = await collectGenerator(tool.stream(context)) + + expect(result.toolUseId).toBe('tool-1') + expect(result.status).toBe('error') + }) + + it('returns error result when agent is already busy', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Slow response' }) + const agent = new Agent({ model, name: 'test-agent', printer: false }) + const tool = new AgentAsTool({ agent }) + + const context1 = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-1', + input: { input: 'First call' }, + }) + const context2 = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-2', + input: { input: 'Second call' }, + }) + + // Start first call but don't fully consume it + const gen1 = tool.stream(context1) + await gen1.next() + + // Second call should get an error immediately + const { result } = await collectGenerator(tool.stream(context2)) + + expect(result.toolUseId).toBe('tool-2') + expect(result.status).toBe('error') + expect(result.content[0]).toEqual( + expect.objectContaining({ + type: 'textBlock', + text: expect.stringContaining('already processing'), + }) + ) + + // Clean up first generator + await collectGenerator(gen1) + }) + }) + + describe('preserveContext', () => { + it('resets agent state between invocations when false (default)', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model, name: 'test-agent', printer: false }) + const tool = new AgentAsTool({ agent }) + + const context1 = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-1', + input: { input: 'Hello' }, + }) + const context2 = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-2', + input: { input: 'Hello again' }, + }) + + await collectGenerator(tool.stream(context1)) + const messagesAfterFirst = agent.messages.length + + await collectGenerator(tool.stream(context2)) + const messagesAfterSecond = agent.messages.length + + // State is reset so both produce the same message count + expect(messagesAfterSecond).toBe(messagesAfterFirst) + }) + + it('preserves agent state across invocations when true', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model, name: 'test-agent', printer: false }) + const tool = new AgentAsTool({ agent, preserveContext: true }) + + const context1 = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-1', + input: { input: 'Hello' }, + }) + const context2 = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-2', + input: { input: 'Hello again' }, + }) + + await collectGenerator(tool.stream(context1)) + const messagesAfterFirst = agent.messages.length + + await collectGenerator(tool.stream(context2)) + const messagesAfterSecond = agent.messages.length + + // Messages should accumulate across invocations + expect(messagesAfterSecond).toBeGreaterThan(messagesAfterFirst) + }) + + it('snapshots at construction time, not first invocation', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model, name: 'test-agent', printer: false }) + const tool = new AgentAsTool({ agent }) + const messagesAtConstruction = agent.messages.length + + // Modify agent state after tool creation + await agent.invoke('Direct invocation') + expect(agent.messages.length).toBeGreaterThan(messagesAtConstruction) + + // First tool call restores to construction-time state, then runs + const context1 = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-1', + input: { input: 'Hello' }, + }) + await collectGenerator(tool.stream(context1)) + const messagesAfterFirstTool = agent.messages.length + + // Second tool call should produce the same count — both reset to construction baseline + const context2 = createMockContext({ + name: 'test-agent', + toolUseId: 'tool-2', + input: { input: 'Hello again' }, + }) + await collectGenerator(tool.stream(context2)) + + expect(agent.messages.length).toBe(messagesAfterFirstTool) + }) + }) + + describe('sessionManager validation', () => { + const mockStorage: SnapshotStorage = { + saveSnapshot: async () => {}, + loadSnapshot: async () => null, + listSnapshotIds: async () => [], + deleteSession: async () => {}, + loadManifest: async () => ({ schemaVersion: '1.0', updatedAt: '12:00:00' }), + saveManifest: async () => {}, + } + + it('throws when preserveContext is false and agent has a sessionManager', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const sessionManager = new SessionManager({ storage: { snapshot: mockStorage } }) + const agent = new Agent({ model, name: 'test-agent', sessionManager }) + + expect(() => new AgentAsTool({ agent })).toThrow(/SessionManager.*conflicts with preserveContext=false/) + }) + + it('throws when preserveContext is explicitly false and agent has a sessionManager', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const sessionManager = new SessionManager({ storage: { snapshot: mockStorage } }) + const agent = new Agent({ model, name: 'test-agent', sessionManager }) + + expect(() => new AgentAsTool({ agent, preserveContext: false })).toThrow( + /SessionManager.*conflicts with preserveContext=false/ + ) + }) + + it('allows preserveContext true with sessionManager', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const sessionManager = new SessionManager({ storage: { snapshot: mockStorage } }) + const agent = new Agent({ model, name: 'test-agent', sessionManager }) + + expect(() => new AgentAsTool({ agent, preserveContext: true })).not.toThrow() + }) + }) + + describe('Agent.asTool', () => { + it('returns an AgentAsTool instance', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'researcher' }) + + const tool = agent.asTool() + + expect(tool).toBeInstanceOf(AgentAsTool) + expect(tool.name).toBe('researcher') + }) + + it('passes options through to AgentAsTool', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, name: 'researcher' }) + + const tool = agent.asTool({ name: 'custom-name', description: 'Custom desc' }) + + expect(tool.name).toBe('custom-name') + expect(tool.description).toBe('Custom desc') + }) + }) + + describe('Agent in ToolList', () => { + it('auto-wraps Agent as AgentAsTool when passed in tools array', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const subAgent = new Agent({ model, name: 'sub-agent', description: 'A sub agent' }) + const parentAgent = new Agent({ model, tools: [subAgent] }) + + const registeredTool = parentAgent.toolRegistry.get('sub-agent') + expect(registeredTool).toBeInstanceOf(AgentAsTool) + expect(registeredTool!.name).toBe('sub-agent') + }) + + it('auto-wraps Agent in nested tools arrays', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const subAgent = new Agent({ model, name: 'nested-agent' }) + const parentAgent = new Agent({ model, tools: [[subAgent]] }) + + const registeredTool = parentAgent.toolRegistry.get('nested-agent') + expect(registeredTool).toBeInstanceOf(AgentAsTool) + }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent.cancel.test.ts b/strands-ts/src/agent/__tests__/agent.cancel.test.ts new file mode 100644 index 0000000000..50e78c1203 --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent.cancel.test.ts @@ -0,0 +1,418 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../agent.js' +import { AfterInvocationEvent, AfterModelCallEvent, BeforeModelCallEvent } from '../../hooks/index.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { TextBlock, ToolResultBlock } from '../../types/messages.js' +import { tool } from '../../tools/tool-factory.js' + +describe('Agent Cancellation', () => { + describe('cancel() when idle', () => { + it('is a no-op and cancelSignal is not aborted', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + expect(agent.cancelSignal.aborted).toBe(false) + agent.cancel() // Should not throw + expect(agent.cancelSignal.aborted).toBe(false) + expect(agent.cancelSignal.aborted).toBe(false) + }) + }) + + describe('cancel at top of loop (checkpoint A)', () => { + it('cancels immediately with already-aborted signal', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const controller = new AbortController() + controller.abort() + + const result = await agent.invoke('Hi', { cancelSignal: controller.signal }) + + expect(result.stopReason).toBe('cancelled') + expect(result.lastMessage.content[0]).toEqual(new TextBlock('Cancelled by user')) + // User message is not appended — cancel fires before message append in the loop + expect(agent.messages).toHaveLength(1) + expect(agent.messages[0]!.role).toBe('assistant') + }) + + it('cancels at top of second cycle when tool calls cancel()', async () => { + const executedTools: string[] = [] + + let agent: Agent + const tool = createMockTool('cancelTool', () => { + executedTools.push('cancelTool') + agent.cancel() + return new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('Done')], + }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'cancelTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Should not reach' }) + + agent = new Agent({ model, tools: [tool], printer: false }) + const result = await agent.invoke('Go') + + expect(result.stopReason).toBe('cancelled') + expect(executedTools).toEqual(['cancelTool']) + // messages: user, assistant(toolUse), user(toolResult), assistant(synthetic cancel) + expect(agent.messages).toHaveLength(4) + expect(agent.messages[3]!.content[0]).toEqual(new TextBlock('Cancelled by user')) + }) + }) + + describe('cancel during model streaming (checkpoint B)', () => { + it('cancels when signal is aborted before model processes events', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + agent.addHook(BeforeModelCallEvent, () => { + agent.cancel() + }) + + const result = await agent.invoke('Hi') + + expect(result.stopReason).toBe('cancelled') + expect(result.lastMessage.content[0]).toEqual(new TextBlock('Cancelled by user')) + }) + }) + + describe('cancel before tool execution (checkpoint C)', () => { + it('creates error results for all pending tools without executing them', async () => { + let toolExecuted = false + const tool = createMockTool('myTool', () => { + toolExecuted = true + return new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('Success')], + }) + }) + + const model = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'myTool', + toolUseId: 'tool-1', + input: {}, + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + agent.addHook(AfterModelCallEvent, (event) => { + if (event.stopData?.stopReason === 'toolUse') { + agent.cancel() + } + }) + + const result = await agent.invoke('Do it') + + expect(result.stopReason).toBe('cancelled') + expect(toolExecuted).toBe(false) + + // Messages: user, assistant(toolUse), user(cancelled toolResult) + expect(agent.messages).toHaveLength(3) + const toolResultMsg = agent.messages[2]! + expect(toolResultMsg.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock('Tool execution cancelled')], + }) + ) + }) + + it('creates error results for multiple pending tools', async () => { + const tool1 = createMockTool('tool1', () => { + return new ToolResultBlock({ toolUseId: 't1', status: 'success', content: [new TextBlock('R1')] }) + }) + const tool2 = createMockTool('tool2', () => { + return new ToolResultBlock({ toolUseId: 't2', status: 'success', content: [new TextBlock('R2')] }) + }) + + const model = new MockMessageModel().addTurn([ + { type: 'toolUseBlock', name: 'tool1', toolUseId: 't1', input: {} }, + { type: 'toolUseBlock', name: 'tool2', toolUseId: 't2', input: {} }, + ]) + + const agent = new Agent({ model, tools: [tool1, tool2], printer: false }) + agent.addHook(AfterModelCallEvent, (event) => { + if (event.stopData?.stopReason === 'toolUse') { + agent.cancel() + } + }) + + const result = await agent.invoke('Do both') + + expect(result.stopReason).toBe('cancelled') + + const toolResultMsg = agent.messages[2]! + expect(toolResultMsg.content).toHaveLength(2) + expect(toolResultMsg.content[0]).toEqual( + new ToolResultBlock({ toolUseId: 't1', status: 'error', content: [new TextBlock('Tool execution cancelled')] }) + ) + expect(toolResultMsg.content[1]).toEqual( + new ToolResultBlock({ toolUseId: 't2', status: 'error', content: [new TextBlock('Tool execution cancelled')] }) + ) + }) + }) + + describe('cancel between sequential tool executions', () => { + it('skips remaining tools after first tool calls cancel()', async () => { + const executedTools: string[] = [] + + let agent: Agent + const tool1 = createMockTool('firstTool', () => { + executedTools.push('firstTool') + agent.cancel() + return new ToolResultBlock({ toolUseId: 't1', status: 'success', content: [new TextBlock('Done')] }) + }) + const tool2 = createMockTool('secondTool', () => { + executedTools.push('secondTool') + return new ToolResultBlock({ toolUseId: 't2', status: 'success', content: [new TextBlock('Done')] }) + }) + + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'firstTool', toolUseId: 't1', input: {} }, + { type: 'toolUseBlock', name: 'secondTool', toolUseId: 't2', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Should not reach' }) + + agent = new Agent({ model, tools: [tool1, tool2], toolExecutor: 'sequential', printer: false }) + const result = await agent.invoke('Go') + + expect(result.stopReason).toBe('cancelled') + expect(executedTools).toEqual(['firstTool']) + + // First tool succeeded, second was cancelled + // messages: user, assistant(toolUse), user(toolResults), assistant(synthetic cancel) + expect(agent.messages).toHaveLength(4) + const toolResultMsg = agent.messages[2]! + expect(toolResultMsg.content[0]).toEqual( + new ToolResultBlock({ toolUseId: 't1', status: 'success', content: [new TextBlock('Done')] }) + ) + expect(toolResultMsg.content[1]).toEqual( + new ToolResultBlock({ toolUseId: 't2', status: 'error', content: [new TextBlock('Tool execution cancelled')] }) + ) + }) + }) + + describe('InvokeOptions.cancelSignal', () => { + it('cancels via external AbortSignal', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const controller = new AbortController() + agent.addHook(BeforeModelCallEvent, () => { + controller.abort() + }) + + const result = await agent.invoke('Hi', { cancelSignal: controller.signal }) + + expect(result.stopReason).toBe('cancelled') + }) + }) + + describe('agent reuse after cancel', () => { + it('allows a second invocation after cancellation', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'First' }) + .addTurn({ type: 'textBlock', text: 'Second' }) + + const agent = new Agent({ model, printer: false }) + + let hookCallCount = 0 + agent.addHook(BeforeModelCallEvent, () => { + hookCallCount++ + if (hookCallCount === 1) { + agent.cancel() + } + }) + + // First invocation: cancelled + const result1 = await agent.invoke('Hello') + expect(result1.stopReason).toBe('cancelled') + + // Second invocation: succeeds normally + const result2 = await agent.invoke('Hello again') + expect(result2.stopReason).toBe('endTurn') + expect(result2.lastMessage.content[0]).toEqual(new TextBlock('Second')) + }) + }) + + describe('cancel via stream break (for-await + break)', () => { + it('appends assistant message when stream is broken out of after cancel()', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'A long story...' }) + const agent = new Agent({ model, printer: false }) + + for await (const event of agent.stream('Write a story')) { + if (event.type === 'modelStreamUpdateEvent') { + agent.cancel() + break + } + } + + const lastMessage = agent.messages[agent.messages.length - 1]! + expect(lastMessage.role).toBe('assistant') + expect(lastMessage.content[0]).toEqual(new TextBlock('Cancelled by user')) + }) + + it('allows reuse after cancellation via stream break', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'A long story...' }) + .addTurn({ type: 'textBlock', text: 'pineapple' }) + + const agent = new Agent({ model, printer: false }) + + // First invocation: cancel during streaming via break + for await (const event of agent.stream('Write a story')) { + if (event.type === 'modelStreamUpdateEvent') { + agent.cancel() + break + } + } + + // Second invocation: should succeed normally + const result = await agent.invoke('Say the word "pineapple"') + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content[0]).toEqual(new TextBlock('pineapple')) + }) + }) + + describe('AfterInvocationEvent', () => { + it('still fires when invocation is cancelled', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + let afterInvocationFired = false + agent.addHook(AfterInvocationEvent, () => { + afterInvocationFired = true + }) + + agent.addHook(BeforeModelCallEvent, () => { + agent.cancel() + }) + + await agent.invoke('Hi') + expect(afterInvocationFired).toBe(true) + }) + }) + + describe('messages state invariants', () => { + it('has no orphaned toolUse blocks after cancel during streaming', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + agent.addHook(BeforeModelCallEvent, () => { + agent.cancel() + }) + + await agent.invoke('Hi') + + // Every assistant message with toolUse blocks must be followed by a user message with matching toolResults + for (let i = 0; i < agent.messages.length; i++) { + const msg = agent.messages[i]! + if (msg.role === 'assistant') { + const toolUseBlocks = msg.content.filter((b) => b.type === 'toolUseBlock') + if (toolUseBlocks.length > 0) { + const nextMsg = agent.messages[i + 1] + expect(nextMsg).toBeDefined() + expect(nextMsg!.role).toBe('user') + const toolResultBlocks = nextMsg!.content.filter((b) => b.type === 'toolResultBlock') + expect(toolResultBlocks).toHaveLength(toolUseBlocks.length) + } + } + } + }) + + it('has no orphaned toolUse blocks after cancel before tools', async () => { + const tool = createMockTool('myTool', () => { + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Done')] }) + }) + + const model = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'myTool', + toolUseId: 'tool-1', + input: {}, + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + agent.addHook(AfterModelCallEvent, (event) => { + if (event.stopData?.stopReason === 'toolUse') { + agent.cancel() + } + }) + + await agent.invoke('Do it') + + // Verify every toolUse has a matching toolResult + for (let i = 0; i < agent.messages.length; i++) { + const msg = agent.messages[i]! + if (msg.role === 'assistant') { + const toolUseBlocks = msg.content.filter((b) => b.type === 'toolUseBlock') + if (toolUseBlocks.length > 0) { + const nextMsg = agent.messages[i + 1] + expect(nextMsg).toBeDefined() + expect(nextMsg!.role).toBe('user') + const toolResultBlocks = nextMsg!.content.filter((b) => b.type === 'toolResultBlock') + expect(toolResultBlocks).toHaveLength(toolUseBlocks.length) + } + } + } + }) + }) + + describe('tool-level cancellation cooperation', () => { + it('exposes cancelSignal to tools via context.agent', async () => { + let signalSeen: AbortSignal | undefined + + const signalTool = tool({ + name: 'signalTool', + description: 'Tool that reads the cancellation signal', + callback: (_input, context) => { + signalSeen = context?.agent.cancelSignal + return 'done' + }, + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'signalTool', toolUseId: 't1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [signalTool], printer: false }) + await agent.invoke('Go') + + expect(signalSeen).toBeInstanceOf(AbortSignal) + expect(signalSeen!.aborted).toBe(false) + }) + + it('signal is aborted when tool checks it after cancel()', async () => { + let signalAborted: boolean | undefined + + let agent: Agent + const checkTool = tool({ + name: 'checkTool', + description: 'Tool that cancels then checks the signal', + callback: (_input, context) => { + agent.cancel() + signalAborted = context?.agent.cancelSignal.aborted + return 'done' + }, + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'checkTool', toolUseId: 't1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Should not reach' }) + + agent = new Agent({ model, tools: [checkTool], printer: false }) + const result = await agent.invoke('Go') + + expect(signalAborted).toBe(true) + expect(result.stopReason).toBe('cancelled') + }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent.concurrent.test.ts b/strands-ts/src/agent/__tests__/agent.concurrent.test.ts new file mode 100644 index 0000000000..4938c210c2 --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent.concurrent.test.ts @@ -0,0 +1,564 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../agent.js' +import { + AfterToolCallEvent, + AfterToolsEvent, + BeforeToolCallEvent, + BeforeToolsEvent, + ToolResultEvent, + ToolStreamUpdateEvent, +} from '../../hooks/index.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { MockPlugin } from '../../__fixtures__/mock-plugin.js' +import { Message, TextBlock, ToolResultBlock } from '../../types/messages.js' +import { Tool, ToolStreamEvent, type ToolContext, type ToolStreamGenerator } from '../../tools/tool.js' +import type { ToolSpec } from '../../tools/types.js' + +/** + * A tool whose `stream()` suspends until `release()` is called. Lets tests + * drive concurrency deterministically without wall-clock sleeps. + * + * `started` resolves as soon as the agent enters the tool's `stream()`, so + * tests can await "both tools in flight" without polling. The tool also + * honors `ctx.agent.cancelSignal`: aborting the signal resolves the gate and + * marks `observations.cancelled = true`. + */ +class GatedTool extends Tool { + name: string + description: string + toolSpec: ToolSpec + + readonly started: Promise + readonly observations = { started: false, cancelled: false, completed: false } + + private _signalStarted!: () => void + private readonly _releaser: Promise + private _release!: () => void + + constructor(name: string) { + super() + this.name = name + this.description = `Gated tool ${name}` + this.toolSpec = { name, description: this.description, inputSchema: { type: 'object', properties: {} } } + this.started = new Promise((resolve) => (this._signalStarted = resolve)) + this._releaser = new Promise((resolve) => (this._release = resolve)) + } + + release(): void { + this._release() + } + + // eslint-disable-next-line require-yield + async *stream(ctx: ToolContext): ToolStreamGenerator { + this.observations.started = true + this._signalStarted() + + await new Promise((resolve) => { + void this._releaser.then(resolve) + ctx.agent.cancelSignal.addEventListener( + 'abort', + () => { + this.observations.cancelled = true + resolve() + }, + { once: true } + ) + }) + + this.observations.completed = true + return new ToolResultBlock({ + toolUseId: ctx.toolUse.toolUseId, + status: 'success', + content: [new TextBlock(`${this.name} done`)], + }) + } +} + +/** + * A streaming tool whose `emit(data)` yields a `ToolStreamEvent` and resolves + * only after the agent has fully dispatched it; `complete()` terminates the + * stream. Tests can drive exact interleaving between tools without timers. + */ +class GatedStreamingTool extends Tool { + name: string + description: string + toolSpec: ToolSpec + + private readonly _queue: { cmd: { type: 'emit'; data: unknown } | { type: 'complete' }; ack: () => void }[] = [] + private _notify: (() => void) | null = null + + constructor(name: string) { + super() + this.name = name + this.description = `Gated streaming tool ${name}` + this.toolSpec = { name, description: this.description, inputSchema: { type: 'object', properties: {} } } + } + + async emit(data: unknown): Promise { + return this._send({ type: 'emit', data }) + } + + async complete(): Promise { + return this._send({ type: 'complete' }) + } + + private _send(cmd: { type: 'emit'; data: unknown } | { type: 'complete' }): Promise { + return new Promise((ack) => { + this._queue.push({ cmd, ack }) + this._notify?.() + this._notify = null + }) + } + + async *stream(ctx: ToolContext): ToolStreamGenerator { + while (true) { + while (this._queue.length === 0) { + await new Promise((resolve) => (this._notify = resolve)) + } + const { cmd, ack } = this._queue.shift()! + if (cmd.type === 'complete') { + ack() + return new ToolResultBlock({ + toolUseId: ctx.toolUse.toolUseId, + status: 'success', + content: [new TextBlock(`${this.name} done`)], + }) + } + yield new ToolStreamEvent({ data: cmd.data }) + ack() + } + } +} + +function twoToolTurn(): MockMessageModel { + return new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'toolA', toolUseId: 'a', input: {} }, + { type: 'toolUseBlock', name: 'toolB', toolUseId: 'b', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) +} + +describe('Agent concurrent tool execution', () => { + it('runs tools concurrently by default', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + // no toolExecutor — relies on the concurrent default + printer: false, + }) + + const invocation = agent.invoke('Go') + // Both tools reach their stream() before either is released — proves + // concurrency without relying on wall-clock overlap. + await Promise.all([toolA.started, toolB.started]) + expect(toolA.observations.completed).toBe(false) + expect(toolB.observations.completed).toBe(false) + toolA.release() + toolB.release() + await invocation + expect(toolA.observations.completed).toBe(true) + expect(toolB.observations.completed).toBe(true) + }) + + it('runs tools sequentially when toolExecutor is sequential', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'sequential', + printer: false, + }) + + const invocation = agent.invoke('Go') + await toolA.started + // B has not started — sequential executor is still blocked on A. + expect(toolB.observations.started).toBe(false) + toolA.release() + await toolB.started + toolB.release() + await invocation + }) + + it('preserves per-tool event ordering while interleaving across tools', async () => { + const toolA = new GatedStreamingTool('toolA') + const toolB = new GatedStreamingTool('toolB') + const plugin = new MockPlugin() + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + plugins: [plugin], + }) + + const invocation = agent.invoke('Go') + // Drive explicit A,B,A,B,A,B interleaving. + await toolA.emit({ tool: 'toolA', step: 0 }) + await toolB.emit({ tool: 'toolB', step: 0 }) + await toolA.emit({ tool: 'toolA', step: 1 }) + await toolB.emit({ tool: 'toolB', step: 1 }) + await toolA.emit({ tool: 'toolA', step: 2 }) + await toolB.emit({ tool: 'toolB', step: 2 }) + await toolA.complete() + await toolB.complete() + await invocation + + // Reduce MockPlugin's invocations to the per-tool lifecycle events we care about. + type Entry = { kind: string; toolUseId?: string; tool?: string } + const events: Entry[] = plugin.invocations + .map((e): Entry | null => { + if (e instanceof BeforeToolCallEvent) return { kind: 'before', toolUseId: e.toolUse.toolUseId } + if (e instanceof AfterToolCallEvent) return { kind: 'after', toolUseId: e.toolUse.toolUseId } + if (e instanceof ToolResultEvent) return { kind: 'result', toolUseId: e.result.toolUseId } + if (e instanceof ToolStreamUpdateEvent) { + const data = e.event.data as { tool?: string } | undefined + return data?.tool !== undefined ? { kind: 'stream', tool: data.tool } : { kind: 'stream' } + } + return null + }) + .filter((e): e is Entry => e !== null) + + // Per-tool subsequence shape: [before, stream*, after, result]. + for (const toolUseId of ['a', 'b']) { + const subseq = events.filter( + (e) => e.toolUseId === toolUseId || (e.kind === 'stream' && e.tool === (toolUseId === 'a' ? 'toolA' : 'toolB')) + ) + const kinds = subseq.map((e) => e.kind) + expect(kinds[0]).toBe('before') + expect(kinds.slice(-2)).toEqual(['after', 'result']) + for (const k of kinds.slice(1, -2)) { + expect(k).toBe('stream') + } + } + + // Cross-tool interleaving: collapse consecutive same-tool stream events + // into runs. Strictly sequential execution produces 2 runs (A,A,A,B,B,B); + // anything > 2 means the stream alternated at least once. + const streamTools = events.filter((e) => e.kind === 'stream').map((e) => e.tool) + const runs = streamTools.reduce<(string | undefined)[]>((acc, t) => { + if (acc.length === 0 || acc[acc.length - 1] !== t) acc.push(t) + return acc + }, []) + expect(runs.length).toBeGreaterThan(2) + }) + + it('retries one tool independently from the other', async () => { + let retriesA = 0 + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + const beforeCalls: string[] = [] + agent.addHook(BeforeToolCallEvent, (e) => void beforeCalls.push(e.toolUse.name)) + agent.addHook(AfterToolCallEvent, (e) => { + if (e.toolUse.name === 'toolA' && retriesA === 0) { + retriesA++ + e.retry = true + } + }) + + const invocation = agent.invoke('Go') + // Release both gates; on retry A re-enters with an already-resolved + // releaser and completes immediately. + await Promise.all([toolA.started, toolB.started]) + toolA.release() + toolB.release() + await invocation + + expect(beforeCalls.filter((n) => n === 'toolA')).toHaveLength(2) + expect(beforeCalls.filter((n) => n === 'toolB')).toHaveLength(1) + }) + + it('cancels all tools when BeforeToolsEvent.cancel is set (concurrent mode)', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + agent.addHook(BeforeToolsEvent, (e) => { + e.cancel = 'hook cancelled' + }) + + let afterMessage: Message | undefined + agent.addHook(AfterToolsEvent, (e) => { + afterMessage = e.message + }) + + await agent.invoke('Go') + + // No tool ever ran. + expect(toolA.observations.started).toBe(false) + expect(toolB.observations.started).toBe(false) + expect(afterMessage!.content).toHaveLength(2) + const r0 = afterMessage!.content[0] as ToolResultBlock + const r1 = afterMessage!.content[1] as ToolResultBlock + expect(r0.status).toBe('error') + expect(r1.status).toBe('error') + expect(r0.toolUseId).toBe('a') + expect(r1.toolUseId).toBe('b') + }) + + it('cancels all tools when agent is cancelled before launch (concurrent mode)', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + agent.addHook(BeforeToolsEvent, () => { + agent.cancel() + }) + + await agent.invoke('Go') + expect(toolA.observations.started).toBe(false) + expect(toolB.observations.started).toBe(false) + }) + + it('cooperative mid-flight cancel — tools honor cancelSignal and exit', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + // Cancel deterministically once both tools have entered their gates. + void Promise.all([toolA.started, toolB.started]).then(() => agent.cancel()) + + await agent.invoke('Go') + + expect(toolA.observations.cancelled).toBe(true) + expect(toolB.observations.cancelled).toBe(true) + }) + + it('handles a throwing tool without affecting siblings', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + // A throwing tool.stream is caught by executeTool's own try/catch and + // normalized to an error ToolResultBlock, so the race loop never sees the + // rejection. This test verifies that normalization path keeps the sibling + // unaffected in concurrent mode. The race loop's `kind: 'throw'` fallback + // is a defensive backstop for generator-level rejections that escape + // executeTool entirely — not expected in normal operation and not exercised + // here. + const results: ToolResultBlock[] = [] + agent.addHook(AfterToolsEvent, (e) => { + for (const b of e.message.content) { + if (b.type === 'toolResultBlock') results.push(b) + } + }) + + // eslint-disable-next-line require-yield + toolA.stream = async function* () { + throw new Error('boom') + } + + const invocation = agent.invoke('Go') + await toolB.started + toolB.release() + await invocation + + const [a, b] = results.sort((x, y) => x.toolUseId.localeCompare(y.toolUseId)) + expect(a!.status).toBe('error') + expect(b!.status).toBe('success') + }) + + it('handles a hallucinated tool name in a batch without affecting siblings', async () => { + const toolA = new GatedTool('toolA') + const agent = new Agent({ + model: new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'toolA', toolUseId: 'a', input: {} }, + { type: 'toolUseBlock', name: 'unknownTool', toolUseId: 'b', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }), + tools: [toolA], + toolExecutor: 'concurrent', + printer: false, + }) + + let afterMessage: Message | undefined + agent.addHook(AfterToolsEvent, (e) => { + afterMessage = e.message + }) + + const invocation = agent.invoke('Go') + await toolA.started + toolA.release() + await invocation + + expect(afterMessage!.content).toHaveLength(2) + const blocks = afterMessage!.content as ToolResultBlock[] + expect(blocks.find((r) => r.toolUseId === 'a')!.status).toBe('success') + expect(blocks.find((r) => r.toolUseId === 'b')!.status).toBe('error') + }) + + it('preserves source order of tool results in AfterToolsEvent.message', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + // Deterministically complete B before A. + let resolveBDone: () => void = () => {} + const bDone = new Promise((resolve) => (resolveBDone = resolve)) + agent.addHook(ToolResultEvent, (e) => { + if (e.result.toolUseId === 'b') resolveBDone() + }) + let afterMessage: Message | undefined + agent.addHook(AfterToolsEvent, (e) => { + afterMessage = e.message + }) + + const invocation = agent.invoke('Go') + await Promise.all([toolA.started, toolB.started]) + toolB.release() + await bDone + toolA.release() + await invocation + + const blocks = afterMessage!.content as ToolResultBlock[] + expect(blocks.map((b) => b.toolUseId)).toEqual(['a', 'b']) + }) + + it('AfterToolsEvent.message contains completed results when consumer breaks mid-stream', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') // never released + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + agent.addHook(BeforeToolCallEvent, (e) => { + if (e.toolUse.name === 'toolA') toolA.release() + }) + + let afterToolsMessage: Message | undefined + agent.addHook(AfterToolsEvent, (e) => { + afterToolsMessage = e.message + }) + + let toolResultsSeen = 0 + for await (const event of agent.stream('Go')) { + if (event.type === 'toolResultEvent') { + toolResultsSeen++ + if (toolResultsSeen === 1) { + // Cancel so toolB (still parked on its gate) observes cancelSignal + // and exits cooperatively — otherwise gen.return() stays blocked on + // a suspended await. + agent.cancel() + break + } + } + } + + expect(afterToolsMessage).toBeDefined() + const blocks = afterToolsMessage!.content.filter((b): b is ToolResultBlock => b.type === 'toolResultBlock') + expect(blocks.length).toBeGreaterThanOrEqual(1) + expect(blocks.some((b) => b.toolUseId === 'a')).toBe(true) + }) + + it('pre-launch agent.cancel() during BeforeToolsEvent produces "Tool execution cancelled" (concurrent)', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + agent.addHook(BeforeToolsEvent, () => { + agent.cancel() + }) + + let afterMessage: Message | undefined + agent.addHook(AfterToolsEvent, (e) => { + afterMessage = e.message + }) + + await agent.invoke('Go') + + expect(toolA.observations.started).toBe(false) + expect(toolB.observations.started).toBe(false) + const blocks = afterMessage!.content as ToolResultBlock[] + expect(blocks).toHaveLength(2) + for (const b of blocks) { + expect((b.content[0] as TextBlock).text).toBe('Tool execution cancelled') + } + }) + + it('closes in-flight generators and includes fallback results when consumer breaks', async () => { + const toolA = new GatedTool('toolA') + const toolB = new GatedTool('toolB') // never released + const agent = new Agent({ + model: twoToolTurn(), + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + agent.addHook(BeforeToolCallEvent, (e) => { + if (e.toolUse.name === 'toolA') toolA.release() + }) + + let afterToolsMessage: Message | undefined + agent.addHook(AfterToolsEvent, (e) => { + afterToolsMessage = e.message + }) + + let toolResultsSeen = 0 + for await (const event of agent.stream('Go')) { + if (event.type === 'toolResultEvent') { + toolResultsSeen++ + if (toolResultsSeen === 1) { + // Cancel so toolB (still parked on its gate) observes cancelSignal + // and exits cooperatively — otherwise gen.return() stays blocked on + // a suspended await. + agent.cancel() + break + } + } + } + + // AfterToolsEvent.message should have entries for both tools: + // toolA completed normally, toolB gets a fallback "interrupted" result. + expect(afterToolsMessage).toBeDefined() + const blocks = afterToolsMessage!.content as ToolResultBlock[] + expect(blocks).toHaveLength(2) + expect(blocks.map((b) => b.toolUseId)).toEqual(['a', 'b']) + expect(blocks.find((b) => b.toolUseId === 'a')!.status).toBe('success') + expect(blocks.find((b) => b.toolUseId === 'b')!.status).toBe('error') + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent.hook.test.ts b/strands-ts/src/agent/__tests__/agent.hook.test.ts new file mode 100644 index 0000000000..5f56659209 --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent.hook.test.ts @@ -0,0 +1,1689 @@ +import { beforeEach, describe, expect, it } from 'vitest' +import { Agent } from '../agent.js' +import { + AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, + AfterToolsEvent, + AgentResultEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + BeforeToolsEvent, + MessageAddedEvent, + ModelStreamUpdateEvent, + InitializedEvent, + HookableEvent, + ModelMessageEvent, +} from '../../hooks/index.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { MockPlugin } from '../../__fixtures__/mock-plugin.js' +import { collectIterator } from '../../__fixtures__/model-test-helpers.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { expectAgentResult } from '../../__fixtures__/agent-helpers.js' +import { Message, TextBlock, ToolResultBlock } from '../../types/messages.js' +import type { Plugin } from '../../plugins/plugin.js' +import type { LocalAgent } from '../../types/agent.js' +import type { Tool } from '../../tools/tool.js' + +describe('Agent Hooks Integration', () => { + let mockPlugin: MockPlugin + + beforeEach(() => { + mockPlugin = new MockPlugin() + }) + + describe('invocation lifecycle', () => { + it('fires hooks during invoke', async () => { + const lifecyclePlugin = new MockPlugin() + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [lifecyclePlugin] }) + + await agent.invoke('Hi') + + expect(lifecyclePlugin.invocations).toHaveLength(7) + + expect(lifecyclePlugin.invocations[0]).toEqual(new InitializedEvent({ agent })) + expect(lifecyclePlugin.invocations[1]).toEqual(new BeforeInvocationEvent({ agent, invocationState: {} })) + expect(lifecyclePlugin.invocations[2]).toEqual( + new MessageAddedEvent({ + agent, + message: new Message({ role: 'user', content: [new TextBlock('Hi')] }), + invocationState: {}, + }) + ) + expect(lifecyclePlugin.invocations[3]).toEqual( + new BeforeModelCallEvent({ + agent, + model: agent.model, + invocationState: {}, + projectedInputTokens: expect.any(Number) as number, + }) + ) + expect(lifecyclePlugin.invocations[4]).toEqual( + new AfterModelCallEvent({ + agent, + model: agent.model, + invocationState: {}, + attemptCount: 1, + stopData: { + stopReason: 'endTurn', + message: new Message({ role: 'assistant', content: [new TextBlock('Hello')] }), + }, + }) + ) + expect(lifecyclePlugin.invocations[5]).toEqual( + new MessageAddedEvent({ + agent, + message: new Message({ role: 'assistant', content: [new TextBlock('Hello')] }), + invocationState: {}, + }) + ) + expect(lifecyclePlugin.invocations[6]).toEqual(new AfterInvocationEvent({ agent, invocationState: {} })) + }) + + it('fires hooks during stream', async () => { + const lifecyclePlugin = new MockPlugin() + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [lifecyclePlugin] }) + + await collectIterator(agent.stream('Hi')) + + expect(lifecyclePlugin.invocations).toHaveLength(7) + + expect(lifecyclePlugin.invocations[0]).toEqual(new InitializedEvent({ agent })) + expect(lifecyclePlugin.invocations[1]).toEqual(new BeforeInvocationEvent({ agent, invocationState: {} })) + expect(lifecyclePlugin.invocations[2]).toEqual( + new MessageAddedEvent({ + agent, + message: new Message({ role: 'user', content: [new TextBlock('Hi')] }), + invocationState: {}, + }) + ) + expect(lifecyclePlugin.invocations[3]).toEqual( + new BeforeModelCallEvent({ + agent, + model: agent.model, + invocationState: {}, + projectedInputTokens: expect.any(Number) as number, + }) + ) + expect(lifecyclePlugin.invocations[4]).toEqual( + new AfterModelCallEvent({ + agent, + model: agent.model, + invocationState: {}, + attemptCount: 1, + stopData: { + stopReason: 'endTurn', + message: new Message({ role: 'assistant', content: [new TextBlock('Hello')] }), + }, + }) + ) + expect(lifecyclePlugin.invocations[5]).toEqual( + new MessageAddedEvent({ + agent, + message: new Message({ role: 'assistant', content: [new TextBlock('Hello')] }), + invocationState: {}, + }) + ) + expect(lifecyclePlugin.invocations[6]).toEqual(new AfterInvocationEvent({ agent, invocationState: {} })) + }) + }) + + describe('runtime hook registration', () => { + it('allows adding hooks after agent creation via addHook', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + // Track events via individual hook registrations + const invocations: HookableEvent[] = [] + agent.addHook(BeforeInvocationEvent, (e) => { + invocations.push(e) + }) + agent.addHook(AfterInvocationEvent, (e) => { + invocations.push(e) + }) + + await agent.invoke('Hi') + + expect(invocations).toHaveLength(2) + expect(invocations[0]).toEqual(new BeforeInvocationEvent({ agent, invocationState: {} })) + expect(invocations[1]).toEqual(new AfterInvocationEvent({ agent, invocationState: {} })) + }) + }) + + describe('multi-turn conversations', () => { + it('fires hooks for each invoke call', async () => { + const lifecyclePlugin = new MockPlugin() + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'First response' }) + .addTurn({ type: 'textBlock', text: 'Second response' }) + + const agent = new Agent({ model, plugins: [lifecyclePlugin] }) + + await agent.invoke('First message') + + // First turn: InitializedEvent + BeforeInvocation, MessageAdded, BeforeModelCall, AfterModelCall, MessageAdded, AfterInvocation + expect(lifecyclePlugin.invocations).toHaveLength(7) + + await agent.invoke('Second message') + + // Should have 13 events total (7 for first turn + 6 for second turn, no InitializedEvent on second) + expect(lifecyclePlugin.invocations).toHaveLength(13) + + // Filter for just Invocation events to verify they fire for each turn + const invocationEvents = lifecyclePlugin.invocations.filter( + (e) => e instanceof BeforeInvocationEvent || e instanceof AfterInvocationEvent + ) + expect(invocationEvents).toHaveLength(4) // 2 for each turn + }) + }) + + describe('tool execution hooks', () => { + it('fires tool hooks during tool execution', async () => { + const tool = createMockTool('testTool', () => { + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Tool result')] }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Final response' }) + + const agent = new Agent({ + model, + tools: [tool], + plugins: [mockPlugin], + }) + + await agent.invoke('Test with tool') + + // Find key events + const beforeToolCallEvents = mockPlugin.invocations.filter((e) => e instanceof BeforeToolCallEvent) + const afterToolCallEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolCallEvent) + const messageAddedEvents = mockPlugin.invocations.filter((e) => e instanceof MessageAddedEvent) + + // Verify tool hooks fired + expect(beforeToolCallEvents.length).toBe(1) + expect(afterToolCallEvents.length).toBe(1) + + // Verify 3 MessageAdded events: input message, assistant with tool use, tool result, final assistant + expect(messageAddedEvents.length).toBe(4) + + // Verify BeforeToolCallEvent + const beforeToolCall = beforeToolCallEvents[0] as BeforeToolCallEvent + expect(beforeToolCall).toEqual( + new BeforeToolCallEvent({ + agent, + toolUse: { name: 'testTool', toolUseId: 'tool-1', input: {} }, + tool, + invocationState: {}, + }) + ) + + // Verify AfterToolCallEvent + const afterToolCall = afterToolCallEvents[0] as AfterToolCallEvent + expect(afterToolCall).toEqual( + new AfterToolCallEvent({ + agent, + toolUse: { name: 'testTool', toolUseId: 'tool-1', input: {} }, + tool, + result: new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('Tool result')], + }), + invocationState: {}, + }) + ) + }) + + it('fires AfterToolCallEvent with error when tool fails', async () => { + const tool = createMockTool('failingTool', () => { + throw new Error('Tool execution failed') + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'failingTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Handled error' }) + + const agent = new Agent({ + model, + tools: [tool], + plugins: [mockPlugin], + }) + + // Agent should complete successfully (tool errors are handled gracefully) + const result = await agent.invoke('Test with failing tool') + expect(result.stopReason).toBe('endTurn') + + // Find AfterToolCallEvent + const afterToolCallEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolCallEvent) + expect(afterToolCallEvents.length).toBe(1) + + const afterToolCall = afterToolCallEvents[0] as AfterToolCallEvent + expect(afterToolCall).toEqual( + new AfterToolCallEvent({ + agent, + toolUse: { name: 'failingTool', toolUseId: 'tool-1', input: {} }, + tool, + result: new ToolResultBlock({ + error: new Error('Tool execution failed'), + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock('Tool execution failed')], + }), + error: new Error('Tool execution failed'), + invocationState: {}, + }) + ) + }) + }) + + describe('ModelStreamUpdateEvent', () => { + it('is yielded in the stream and dispatched to hooks', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + + const streamUpdateEvents: ModelStreamUpdateEvent[] = [] + const agent = new Agent({ model }) + agent.addHook(ModelStreamUpdateEvent, (event: ModelStreamUpdateEvent) => { + streamUpdateEvents.push(event) + }) + + // Collect all stream events + const allStreamEvents = [] + for await (const event of agent.stream('Test')) { + allStreamEvents.push(event) + } + + // Should be yielded in the stream + const streamUpdates = allStreamEvents.filter((e) => e instanceof ModelStreamUpdateEvent) + expect(streamUpdates.length).toBeGreaterThan(0) + + // Should also fire as hook + expect(streamUpdateEvents.length).toBeGreaterThan(0) + + // Stream and hook should receive the same event instances + expect(streamUpdates).toStrictEqual(streamUpdateEvents) + }) + }) + + describe('MessageAddedEvent', () => { + it('fires for initial user input', async () => { + const initialMessage = { role: 'user' as const, content: [{ type: 'textBlock' as const, text: 'Initial' }] } + + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + + const agent = new Agent({ + model, + messages: [initialMessage], + plugins: [mockPlugin], + }) + + await agent.invoke('New message') + + const messageAddedEvents = mockPlugin.invocations.filter((e) => e instanceof MessageAddedEvent) + + // Should have 2 MessageAdded event + expect(messageAddedEvents).toHaveLength(2) + + expect(messageAddedEvents[0]).toEqual( + new MessageAddedEvent({ + agent, + message: new Message({ role: 'user', content: [new TextBlock('New message')] }), + invocationState: {}, + }) + ) + expect(messageAddedEvents[1]).toEqual( + new MessageAddedEvent({ + agent, + message: new Message({ role: 'assistant', content: [new TextBlock('Response')] }), + invocationState: {}, + }) + ) + }) + }) + + describe('AfterModelCallEvent retry', () => { + it('does not duplicate user messages on error retry', async () => { + const model = new MockMessageModel() + .addTurn(new Error('context overflow')) + .addTurn({ type: 'textBlock', text: 'Success' }) + + const agent = new Agent({ model, printer: false }) + agent.addHook(AfterModelCallEvent, (event: AfterModelCallEvent) => { + if (event.error) { + event.retry = true + } + }) + + await agent.invoke('Hello') + + // Count user messages with "Hello" — should be exactly 1 + const userMessages = agent.messages.filter( + (m) => m.role === 'user' && m.content.some((b) => b.type === 'textBlock' && b.text === 'Hello') + ) + expect(userMessages).toHaveLength(1) + }) + + it('does not duplicate user messages on success retry', async () => { + let callCount = 0 + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'First' }) + .addTurn({ type: 'textBlock', text: 'Second' }) + + const agent = new Agent({ model, printer: false }) + agent.addHook(AfterModelCallEvent, (event: AfterModelCallEvent) => { + callCount++ + if (callCount === 1 && !event.error) { + event.retry = true + } + }) + + await agent.invoke('Hello') + + const userMessages = agent.messages.filter( + (m) => m.role === 'user' && m.content.some((b) => b.type === 'textBlock' && b.text === 'Hello') + ) + expect(userMessages).toHaveLength(1) + }) + + it('retries model call when hook sets retry', async () => { + let callCount = 0 + const model = new MockMessageModel() + .addTurn(new Error('First attempt failed')) + .addTurn({ type: 'textBlock', text: 'Success after retry' }) + + const agent = new Agent({ model }) + agent.addHook(AfterModelCallEvent, (event: AfterModelCallEvent) => { + callCount++ + if (callCount === 1 && event.error) { + event.retry = true + } + }) + + const result = await agent.invoke('Test') + + expect(result.lastMessage.content[0]).toEqual({ type: 'textBlock', text: 'Success after retry' }) + expect(callCount).toBe(2) + }) + + it('does not retry when retry is not set', async () => { + const model = new MockMessageModel().addTurn(new Error('Failure')) + const agent = new Agent({ model }) + + await expect(agent.invoke('Test')).rejects.toThrow('Failure') + }) + + it('retries model call on success when hook requests it', async () => { + let callCount = 0 + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'First response' }) + .addTurn({ type: 'textBlock', text: 'Second response after retry' }) + + const agent = new Agent({ model }) + agent.addHook(AfterModelCallEvent, (event: AfterModelCallEvent) => { + callCount++ + if (callCount === 1 && !event.error) { + event.retry = true + } + }) + + const result = await agent.invoke('Test') + + expect(result.lastMessage.content[0]).toEqual({ type: 'textBlock', text: 'Second response after retry' }) + expect(callCount).toBe(2) + }) + }) + + describe('AfterToolCallEvent retry', () => { + it('retries tool call when hook sets retry', async () => { + let toolCallCount = 0 + const tool = createMockTool('retryableTool', () => { + toolCallCount++ + if (toolCallCount === 1) { + throw new Error('First attempt failed') + } + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Success')] }) + }) + + let hookCallCount = 0 + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'retryableTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(AfterToolCallEvent, (event: AfterToolCallEvent) => { + hookCallCount++ + if (hookCallCount === 1 && event.error) { + event.retry = true + } + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(toolCallCount).toBe(2) + expect(hookCallCount).toBe(2) + }) + + it('does not retry tool call when retry is not set', async () => { + let toolCallCount = 0 + const tool = createMockTool('failingTool', () => { + toolCallCount++ + throw new Error('Tool failed') + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'failingTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Handled error' }) + + const agent = new Agent({ model, tools: [tool] }) + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(toolCallCount).toBe(1) + }) + + it('fires BeforeToolCallEvent on each retry', async () => { + let toolCallCount = 0 + const tool = createMockTool('retryableTool', () => { + toolCallCount++ + return new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock(`Result ${toolCallCount}`)], + }) + }) + + let beforeCount = 0 + let afterCount = 0 + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'retryableTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(BeforeToolCallEvent, () => { + beforeCount++ + }) + agent.addHook(AfterToolCallEvent, (event: AfterToolCallEvent) => { + afterCount++ + if (afterCount === 1) { + event.retry = true + } + }) + + await agent.invoke('Test') + + expect(beforeCount).toBe(2) + expect(afterCount).toBe(2) + expect(toolCallCount).toBe(2) + }) + + it('retries tool call on success when hook requests it', async () => { + let toolCallCount = 0 + const tool = createMockTool('successTool', () => { + toolCallCount++ + return new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock(`Result ${toolCallCount}`)], + }) + }) + + let hookCallCount = 0 + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'successTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(AfterToolCallEvent, (event: AfterToolCallEvent) => { + hookCallCount++ + if (hookCallCount === 1) { + event.retry = true + } + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(toolCallCount).toBe(2) + expect(hookCallCount).toBe(2) + }) + }) + + describe('cancel tool via hooks', () => { + it('cancels individual tool call with default message when cancel is true', async () => { + let toolExecuted = false + const tool = createMockTool('blockedTool', () => { + toolExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Success')] }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'blockedTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool], plugins: [mockPlugin] }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.cancel = true + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(false) + + const afterToolCallEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolCallEvent) + expect(afterToolCallEvents).toHaveLength(1) + const afterEvent = afterToolCallEvents[0] as AfterToolCallEvent + expect(afterEvent.result).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock('Tool cancelled by hook')], + }) + ) + }) + + it('cancels individual tool call with custom message when cancel is a string', async () => { + let toolExecuted = false + const tool = createMockTool('blockedTool', () => { + toolExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Success')] }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'blockedTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool], plugins: [mockPlugin] }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.cancel = 'Tool call limit exceeded' + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(false) + + const afterToolCallEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolCallEvent) + expect(afterToolCallEvents).toHaveLength(1) + const afterEvent = afterToolCallEvents[0] as AfterToolCallEvent + expect(afterEvent.result).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock('Tool call limit exceeded')], + }) + ) + }) + + it('cancels only specific tools when BeforeToolCallEvent selectively cancels', async () => { + const executedTools: string[] = [] + const tool1 = createMockTool('allowedTool', () => { + executedTools.push('allowedTool') + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Allowed')] }) + }) + const tool2 = createMockTool('blockedTool', () => { + executedTools.push('blockedTool') + return new ToolResultBlock({ toolUseId: 'tool-2', status: 'success', content: [new TextBlock('Blocked')] }) + }) + + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'allowedTool', toolUseId: 'tool-1', input: {} }, + { type: 'toolUseBlock', name: 'blockedTool', toolUseId: 'tool-2', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool1, tool2], plugins: [mockPlugin] }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + if (event.toolUse.name === 'blockedTool') { + event.cancel = 'This tool is blocked' + } + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(executedTools).toEqual(['allowedTool']) + + const afterToolCallEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolCallEvent) + expect(afterToolCallEvents).toHaveLength(2) + expect((afterToolCallEvents[0] as AfterToolCallEvent).result).toEqual( + new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Allowed')] }) + ) + expect((afterToolCallEvents[1] as AfterToolCallEvent).result).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-2', + status: 'error', + content: [new TextBlock('This tool is blocked')], + }) + ) + }) + + it('cancels all tools with default message when BeforeToolsEvent.cancel is true', async () => { + let toolExecuted = false + const tool = createMockTool('blockedTool', () => { + toolExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Success')] }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'blockedTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool], plugins: [mockPlugin] }) + agent.addHook(BeforeToolsEvent, (event: BeforeToolsEvent) => { + event.cancel = true + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(false) + + const afterToolsEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolsEvent) + expect(afterToolsEvents).toHaveLength(1) + const afterEvent = afterToolsEvents[0] as AfterToolsEvent + expect(afterEvent.message.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock('Tool cancelled by hook')], + }) + ) + }) + + it('cancels all tools with custom message when BeforeToolsEvent.cancel is a string', async () => { + let toolExecuted = false + const tool = createMockTool('blockedTool', () => { + toolExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Success')] }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'blockedTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool], plugins: [mockPlugin] }) + agent.addHook(BeforeToolsEvent, (event: BeforeToolsEvent) => { + event.cancel = 'All tools blocked' + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(false) + + const afterToolsEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolsEvent) + expect(afterToolsEvents).toHaveLength(1) + const afterEvent = afterToolsEvents[0] as AfterToolsEvent + expect(afterEvent.message.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock('All tools blocked')], + }) + ) + }) + + it('cancels all tools in a batch via BeforeToolsEvent with correct toolUseIds', async () => { + const executedTools: string[] = [] + const tool1 = createMockTool('tool1', () => { + executedTools.push('tool1') + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Result 1')] }) + }) + const tool2 = createMockTool('tool2', () => { + executedTools.push('tool2') + return new ToolResultBlock({ toolUseId: 'tool-2', status: 'success', content: [new TextBlock('Result 2')] }) + }) + + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'tool1', toolUseId: 'tool-1', input: {} }, + { type: 'toolUseBlock', name: 'tool2', toolUseId: 'tool-2', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool1, tool2], plugins: [mockPlugin] }) + agent.addHook(BeforeToolsEvent, (event: BeforeToolsEvent) => { + event.cancel = 'Batch cancelled' + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(executedTools).toEqual([]) + + const afterToolsEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolsEvent) + expect(afterToolsEvents).toHaveLength(1) + const afterEvent = afterToolsEvents[0] as AfterToolsEvent + expect(afterEvent.message.content).toEqual([ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock('Batch cancelled')], + }), + new ToolResultBlock({ + toolUseId: 'tool-2', + status: 'error', + content: [new TextBlock('Batch cancelled')], + }), + ]) + }) + + it('emits cancel events correctly via stream()', async () => { + let toolExecuted = false + const tool = createMockTool('blockedTool', () => { + toolExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Success')] }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'blockedTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.cancel = 'Cancelled via stream' + }) + + const items = await collectIterator(agent.stream('Test')) + + expect(toolExecuted).toBe(false) + + const beforeToolCallEvents = items.filter((e) => e instanceof BeforeToolCallEvent) + const afterToolCallEvents = items.filter((e) => e instanceof AfterToolCallEvent) + expect(beforeToolCallEvents).toHaveLength(1) + expect(afterToolCallEvents).toHaveLength(1) + + const afterEvent = afterToolCallEvents[0] as AfterToolCallEvent + expect(afterEvent.result).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock('Cancelled via stream')], + }) + ) + }) + + it('allows retry after cancel on BeforeToolCallEvent', async () => { + let toolCallCount = 0 + const tool = createMockTool('retryTool', () => { + toolCallCount++ + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Success')] }) + }) + + let beforeCount = 0 + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'retryTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + beforeCount++ + if (beforeCount === 1) { + event.cancel = 'Not yet' + } + }) + agent.addHook(AfterToolCallEvent, (event: AfterToolCallEvent) => { + if (event.result.status === 'error' && beforeCount === 1) { + event.retry = true + } + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(beforeCount).toBe(2) + expect(toolCallCount).toBe(1) // Only executed on second attempt + }) + + it('allows hooks to replace result on AfterToolCallEvent', async () => { + const tool = createMockTool('myTool', () => { + return new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('original result')], + }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(AfterToolCallEvent, (event: AfterToolCallEvent) => { + event.result = new ToolResultBlock({ + toolUseId: event.result.toolUseId, + status: 'success', + content: [new TextBlock('replaced result')], + }) + }) + + await agent.invoke('Test') + + const toolResultMessage = agent.messages.find( + (m) => m.role === 'user' && m.content.some((b) => b.type === 'toolResultBlock') + ) + const toolResultBlock = toolResultMessage!.content.find((b): b is ToolResultBlock => b.type === 'toolResultBlock') + expect(toolResultBlock).toStrictEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('replaced result')], + }) + ) + }) + }) + + describe('AfterToolsEvent.endTurn', () => { + const makeSingleToolSetup = (): { tool: Tool; model: MockMessageModel } => ({ + tool: createMockTool('myTool', () => { + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('result')] }) + }), + model: new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Should not reach this' }), + }) + + it('halts the loop when endTurn is true with default message', async () => { + const { tool, model } = makeSingleToolSetup() + const agent = new Agent({ model, tools: [tool], plugins: [mockPlugin] }) + agent.addHook(AfterToolsEvent, (event: AfterToolsEvent) => { + event.endTurn = true + }) + + const result = await agent.invoke('Test') + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ + role: 'assistant', + content: expect.arrayContaining([ + expect.objectContaining({ type: 'textBlock', text: 'Turn ended early by hook after tool execution' }), + ]), + }), + }) + ) + expect(model.callCount).toBe(1) + }) + + it('halts the loop with custom assistant message when endTurn is a string', async () => { + const { tool, model } = makeSingleToolSetup() + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(AfterToolsEvent, (event: AfterToolsEvent) => { + event.endTurn = 'enough information gathered' + }) + + const result = await agent.invoke('Test') + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ + role: 'assistant', + content: expect.arrayContaining([ + expect.objectContaining({ type: 'textBlock', text: 'enough information gathered' }), + ]), + }), + }) + ) + expect(model.callCount).toBe(1) + }) + + it('does not halt when endTurn is false (default)', async () => { + const { tool, model } = makeSingleToolSetup() + const agent = new Agent({ model, tools: [tool] }) + + const result = await agent.invoke('Test') + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ role: 'assistant' }), + }) + ) + expect(model.callCount).toBe(2) + }) + + it('treats empty string endTurn as falsy (does not halt)', async () => { + const { tool, model } = makeSingleToolSetup() + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(AfterToolsEvent, (event: AfterToolsEvent) => { + event.endTurn = '' + }) + + const result = await agent.invoke('Test') + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ role: 'assistant' }), + }) + ) + expect(model.callCount).toBe(2) + }) + + it('appends tool results and default endTurn message to conversation history', async () => { + const { tool, model } = makeSingleToolSetup() + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(AfterToolsEvent, (event: AfterToolsEvent) => { + event.endTurn = true + }) + + await agent.invoke('Test') + + expect(agent.messages).toHaveLength(4) + + expect(agent.messages[0]!.role).toBe('user') + expect(agent.messages[1]!.role).toBe('assistant') + expect(agent.messages[1]!.content).toEqual( + expect.arrayContaining([expect.objectContaining({ type: 'toolUseBlock' })]) + ) + expect(agent.messages[2]!.role).toBe('user') + expect(agent.messages[2]!.content).toEqual( + expect.arrayContaining([expect.objectContaining({ type: 'toolResultBlock' })]) + ) + expect(agent.messages[3]!.role).toBe('assistant') + expect(agent.messages[3]!.content).toEqual( + expect.arrayContaining([ + expect.objectContaining({ type: 'textBlock', text: 'Turn ended early by hook after tool execution' }), + ]) + ) + }) + + it('halts the loop with concurrent tool execution', async () => { + const tool1 = createMockTool('tool1', () => { + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('Result 1')] }) + }) + const tool2 = createMockTool('tool2', () => { + return new ToolResultBlock({ toolUseId: 'tool-2', status: 'success', content: [new TextBlock('Result 2')] }) + }) + + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'tool1', toolUseId: 'tool-1', input: {} }, + { type: 'toolUseBlock', name: 'tool2', toolUseId: 'tool-2', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Should not reach this' }) + + const agent = new Agent({ model, tools: [tool1, tool2], toolExecutor: 'concurrent' }) + agent.addHook(AfterToolsEvent, (event: AfterToolsEvent) => { + event.endTurn = true + }) + + const result = await agent.invoke('Test') + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ role: 'assistant' }), + }) + ) + expect(model.callCount).toBe(1) + }) + + it('emits AfterToolsEvent with endTurn via stream()', async () => { + const { tool, model } = makeSingleToolSetup() + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(AfterToolsEvent, (event: AfterToolsEvent) => { + event.endTurn = true + }) + + const items = await collectIterator(agent.stream('Test')) + + const afterToolsEvents = items.filter((e) => e instanceof AfterToolsEvent) + expect(afterToolsEvents).toHaveLength(1) + expect((afterToolsEvents[0] as AfterToolsEvent).endTurn).toBe(true) + + const resultEvents = items.filter((e) => e instanceof AgentResultEvent) + expect(resultEvents).toHaveLength(1) + expect((resultEvents[0] as AgentResultEvent).result.stopReason).toBe('endTurn') + }) + + it('halts even when set on a cancelled-tools AfterToolsEvent', async () => { + const { tool, model } = makeSingleToolSetup() + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(BeforeToolsEvent, (event: BeforeToolsEvent) => { + event.cancel = true + }) + agent.addHook(AfterToolsEvent, (event: AfterToolsEvent) => { + event.endTurn = true + }) + + const result = await agent.invoke('Test') + + expect(result).toEqual( + expect.objectContaining({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ role: 'assistant' }), + }) + ) + expect(model.callCount).toBe(1) + }) + }) + + describe('cancel invocation via hooks', () => { + it('cancels invocation with default message when cancel is true', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [mockPlugin] }) + agent.addHook(BeforeInvocationEvent, (event: BeforeInvocationEvent) => { + event.cancel = true + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content[0]).toEqual(new TextBlock('invocation denied by hook')) + + const beforeModelCallEvents = mockPlugin.invocations.filter((e) => e instanceof BeforeModelCallEvent) + expect(beforeModelCallEvents).toHaveLength(0) + }) + + it('cancels invocation with custom message when cancel is a string', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [mockPlugin] }) + agent.addHook(BeforeInvocationEvent, (event: BeforeInvocationEvent) => { + event.cancel = 'Unauthorized user' + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content[0]).toEqual(new TextBlock('Unauthorized user')) + }) + + it('does not append user message when invocation is cancelled', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + agent.addHook(BeforeInvocationEvent, (event: BeforeInvocationEvent) => { + event.cancel = true + }) + + await agent.invoke('Test') + + expect(agent.messages).toHaveLength(1) + expect(agent.messages[0]!.role).toBe('assistant') + }) + + it('emits AfterInvocationEvent when invocation is cancelled', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [mockPlugin] }) + agent.addHook(BeforeInvocationEvent, (event: BeforeInvocationEvent) => { + event.cancel = true + }) + + await agent.invoke('Test') + + const beforeInvocationEvents = mockPlugin.invocations.filter((e) => e instanceof BeforeInvocationEvent) + const afterInvocationEvents = mockPlugin.invocations.filter((e) => e instanceof AfterInvocationEvent) + expect(beforeInvocationEvents).toHaveLength(1) + expect(afterInvocationEvents).toHaveLength(1) + }) + }) + + describe('cancel model call via hooks', () => { + it('cancels model call with default message when cancel is true', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [mockPlugin] }) + agent.addHook(BeforeModelCallEvent, (event: BeforeModelCallEvent) => { + event.cancel = true + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content[0]).toEqual(new TextBlock('model call denied by hook')) + }) + + it('cancels model call with custom message when cancel is a string', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [mockPlugin] }) + agent.addHook(BeforeModelCallEvent, (event: BeforeModelCallEvent) => { + event.cancel = 'Rate limited' + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content[0]).toEqual(new TextBlock('Rate limited')) + }) + + it('emits AfterModelCallEvent when model call is cancelled', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [mockPlugin] }) + agent.addHook(BeforeModelCallEvent, (event: BeforeModelCallEvent) => { + event.cancel = true + }) + + await agent.invoke('Test') + + const beforeModelCallEvents = mockPlugin.invocations.filter((e) => e instanceof BeforeModelCallEvent) + const afterModelCallEvents = mockPlugin.invocations.filter((e) => e instanceof AfterModelCallEvent) + expect(beforeModelCallEvents).toHaveLength(1) + expect(afterModelCallEvents).toHaveLength(1) + }) + + it('does not emit ModelMessageEvent when model call is cancelled', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [mockPlugin] }) + agent.addHook(BeforeModelCallEvent, (event: BeforeModelCallEvent) => { + event.cancel = true + }) + + await agent.invoke('Test') + + const modelMessageEvents = mockPlugin.invocations.filter((e) => e instanceof ModelMessageEvent) + expect(modelMessageEvents).toHaveLength(0) + }) + + it('allows retry after cancel on model call', async () => { + let beforeCount = 0 + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [mockPlugin] }) + agent.addHook(BeforeModelCallEvent, (event: BeforeModelCallEvent) => { + beforeCount++ + if (beforeCount === 1) { + event.cancel = 'Not yet' + } + }) + agent.addHook(AfterModelCallEvent, (event: AfterModelCallEvent) => { + if (beforeCount === 1) { + event.retry = true + } + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(beforeCount).toBe(2) + expect(result.lastMessage.content[0]).toEqual(new TextBlock('Hello')) + }) + }) + + describe('BeforeToolCallEvent selectedTool', () => { + it('invokes the replacement tool instead of the registry tool', async () => { + let originalExecuted = false + let replacementExecuted = false + const originalTool = createMockTool('originalTool', () => { + originalExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('original')] }) + }) + const replacementTool = createMockTool('replacementTool', () => { + replacementExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('replacement')] }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'originalTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [originalTool], plugins: [mockPlugin] }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.selectedTool = replacementTool + }) + + await agent.invoke('Test') + + expect(originalExecuted).toBe(false) + expect(replacementExecuted).toBe(true) + + const afterToolCallEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolCallEvent) + expect(afterToolCallEvents).toHaveLength(1) + expect((afterToolCallEvents[0] as AfterToolCallEvent).result.content).toEqual([new TextBlock('replacement')]) + }) + + it('cancel wins over selectedTool', async () => { + let replacementExecuted = false + const replacementTool = createMockTool('replacementTool', () => { + replacementExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('replacement')] }) + }) + const registryTool = createMockTool('registryTool', () => { + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('registry')] }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'registryTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [registryTool], plugins: [mockPlugin] }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.selectedTool = replacementTool + event.cancel = 'blocked' + }) + + await agent.invoke('Test') + + expect(replacementExecuted).toBe(false) + + // AfterToolCallEvent.tool should report the selectedTool even on the cancel path, + // so observability hooks see a consistent `tool` value regardless of branch. + const afterToolCallEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolCallEvent) + expect(afterToolCallEvents).toHaveLength(1) + expect((afterToolCallEvents[0] as AfterToolCallEvent).tool).toBe(replacementTool) + }) + + it('works with concurrent tool executor', async () => { + let originalExecuted = false + let replacementExecuted = false + const originalTool = createMockTool('originalTool', () => { + originalExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('original')] }) + }) + const replacementTool = createMockTool('replacementTool', () => { + replacementExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('replacement')] }) + }) + const otherTool = createMockTool('otherTool', () => { + return new ToolResultBlock({ toolUseId: 'tool-2', status: 'success', content: [new TextBlock('other')] }) + }) + + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'originalTool', toolUseId: 'tool-1', input: {} }, + { type: 'toolUseBlock', name: 'otherTool', toolUseId: 'tool-2', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ + model, + tools: [originalTool, otherTool], + toolExecutor: 'concurrent', + }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + if (event.toolUse.name === 'originalTool') { + event.selectedTool = replacementTool + } + }) + + await agent.invoke('Test') + + expect(originalExecuted).toBe(false) + expect(replacementExecuted).toBe(true) + }) + }) + + describe('BeforeToolCallEvent toolUse mutation', () => { + it('passes mutated input to the tool', async () => { + const capturedInputs: unknown[] = [] + const tool = createMockTool('tool', () => { + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('ok')] }) + }) + // Wrap to capture input via the context the tool receives. + const capturingTool = { + ...tool, + async *stream(context: Parameters[0]) { + capturedInputs.push(context.toolUse.input) + return yield* tool.stream(context) + }, + } + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'tool', toolUseId: 'tool-1', input: { a: 1 } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [capturingTool] }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.toolUse.input = { a: 2, injected: true } + }) + + await agent.invoke('Test') + + expect(capturedInputs).toEqual([{ a: 2, injected: true }]) + }) + + it('re-resolves the tool when hook renames toolUse.name', async () => { + let origExecuted = false + let renamedExecuted = false + const origTool = createMockTool('orig', () => { + origExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('orig')] }) + }) + const renamedTool = createMockTool('renamed', () => { + renamedExecuted = true + return new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('renamed')] }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'orig', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [origTool, renamedTool] }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.toolUse.name = 'renamed' + }) + + await agent.invoke('Test') + + expect(origExecuted).toBe(false) + expect(renamedExecuted).toBe(true) + }) + + it('works with concurrent tool executor', async () => { + const capturedInputs: Record = {} + const baseA = createMockTool('toolA', () => { + return new ToolResultBlock({ toolUseId: 'a', status: 'success', content: [new TextBlock('a done')] }) + }) + const baseB = createMockTool('toolB', () => { + return new ToolResultBlock({ toolUseId: 'b', status: 'success', content: [new TextBlock('b done')] }) + }) + const toolA = { + ...baseA, + async *stream(context: Parameters[0]) { + capturedInputs[context.toolUse.name] = context.toolUse.input + return yield* baseA.stream(context) + }, + } + const toolB = { + ...baseB, + async *stream(context: Parameters[0]) { + capturedInputs[context.toolUse.name] = context.toolUse.input + return yield* baseB.stream(context) + }, + } + + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'toolA', toolUseId: 'a', input: { original: 'a' } }, + { type: 'toolUseBlock', name: 'toolB', toolUseId: 'b', input: { original: 'b' } }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [toolA, toolB], toolExecutor: 'concurrent' }) + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.toolUse.input = { mutated: event.toolUse.name } + }) + + await agent.invoke('Test') + + expect(capturedInputs).toEqual({ + toolA: { mutated: 'toolA' }, + toolB: { mutated: 'toolB' }, + }) + }) + }) + + describe('AfterToolCallEvent result mutation', () => { + it('propagates mutated result into the conversation message', async () => { + const tool = createMockTool('tool', () => { + return new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('SECRET_VALUE')], + }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'tool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool] }) + agent.addHook(AfterToolCallEvent, (event: AfterToolCallEvent) => { + event.result = new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('[REDACTED]')], + }) + }) + + await agent.invoke('Test') + + const toolResultMessage = agent.messages.find((m) => + m.content.some((b) => b.type === 'toolResultBlock' && b.toolUseId === 'tool-1') + ) + expect(toolResultMessage).toBeDefined() + const block = toolResultMessage!.content.find( + (b): b is ToolResultBlock => b.type === 'toolResultBlock' && b.toolUseId === 'tool-1' + ) + expect(block!.content).toEqual([new TextBlock('[REDACTED]')]) + }) + + it('propagates mutated result into AfterToolsEvent', async () => { + const tool = createMockTool('tool', () => { + return new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('SECRET_VALUE')], + }) + }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'tool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, tools: [tool], plugins: [mockPlugin] }) + agent.addHook(AfterToolCallEvent, (event: AfterToolCallEvent) => { + event.result = new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('[REDACTED]')], + }) + }) + + await agent.invoke('Test') + + const afterToolsEvents = mockPlugin.invocations.filter((e) => e instanceof AfterToolsEvent) + expect(afterToolsEvents).toHaveLength(1) + const block = (afterToolsEvents[0] as AfterToolsEvent).message.content.find( + (b): b is ToolResultBlock => b.type === 'toolResultBlock' && b.toolUseId === 'tool-1' + ) + expect(block!.content).toEqual([new TextBlock('[REDACTED]')]) + }) + }) + + describe('AfterInvocationEvent resume', () => { + it('re-invokes the agent with the resume args', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'first' }) + .addTurn({ type: 'textBlock', text: 'second' }) + + let invocationCount = 0 + const agent = new Agent({ model }) + agent.addHook(AfterInvocationEvent, (event: AfterInvocationEvent) => { + invocationCount++ + if (invocationCount === 1) { + event.resume = 'follow-up' + } + }) + + const result = await agent.invoke('initial') + + expect(invocationCount).toBe(2) + expect(result).toEqual( + expectAgentResult({ + stopReason: 'endTurn', + messageText: 'second', + // Meter cycleCount is cumulative across the resume chain (1 cycle per invocation x 2). + cycleCount: 2, + }) + ) + }) + + it('chains multiple resumes', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'a' }) + .addTurn({ type: 'textBlock', text: 'b' }) + .addTurn({ type: 'textBlock', text: 'c' }) + + let invocationCount = 0 + const agent = new Agent({ model }) + agent.addHook(AfterInvocationEvent, (event: AfterInvocationEvent) => { + invocationCount++ + if (invocationCount === 1) event.resume = 'second' + else if (invocationCount === 2) event.resume = 'third' + }) + + const result = await agent.invoke('first') + + expect(invocationCount).toBe(3) + expect(result.lastMessage.content[0]).toEqual({ type: 'textBlock', text: 'c' }) + }) + + it('does not resume when resume is left undefined', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'only' }) + + let invocationCount = 0 + const agent = new Agent({ model }) + agent.addHook(AfterInvocationEvent, () => { + invocationCount++ + }) + + await agent.invoke('hi') + + expect(invocationCount).toBe(1) + }) + + it('does not resume when the invocation errors', async () => { + const model = new MockMessageModel().addTurn(new Error('boom')) + + let invocationCount = 0 + const agent = new Agent({ model }) + agent.addHook(AfterInvocationEvent, (event: AfterInvocationEvent) => { + invocationCount++ + event.resume = 'should-not-run' + }) + + await expect(agent.invoke('hi')).rejects.toThrow('boom') + expect(invocationCount).toBe(1) + }) + + it('first-registered hook wins when multiple hooks set resume', async () => { + // AfterInvocationEvent reverses callback order (_shouldReverseCallbacks=true), + // so the first-registered hook fires last and its resume value wins. + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'first' }) + .addTurn({ type: 'textBlock', text: 'second' }) + + let invocationCount = 0 + const agent = new Agent({ model }) + agent.addHook(BeforeInvocationEvent, () => { + invocationCount++ + }) + agent.addHook(AfterInvocationEvent, (event: AfterInvocationEvent) => { + if (invocationCount === 1) event.resume = 'first-registered wins' + }) + agent.addHook(AfterInvocationEvent, (event: AfterInvocationEvent) => { + if (invocationCount === 1) event.resume = 'second-registered loses' + }) + + await agent.invoke('initial') + + const userTexts = agent.messages + .filter((m) => m.role === 'user') + .flatMap((m) => m.content.filter((b): b is TextBlock => b.type === 'textBlock').map((b) => b.text)) + expect(userTexts).toEqual(['initial', 'first-registered wins']) + }) + + it('ignores resume set during an erroring invocation', async () => { + // Resume should not fire when the invocation ends with an error, even if + // AfterInvocationEvent (which fires in _stream's finally) still runs. + const model = new MockMessageModel().addTurn(new Error('boom')) + + let resumeFired = false + const agent = new Agent({ model }) + agent.addHook(AfterInvocationEvent, (event: AfterInvocationEvent) => { + event.resume = 'should not run' + }) + agent.addHook(BeforeInvocationEvent, () => { + // Track whether BeforeInvocationEvent fires a second time (would indicate resume ran). + if (resumeFired) throw new Error('unexpected second invocation') + resumeFired = true + }) + + await expect(agent.invoke('hi')).rejects.toThrow('boom') + }) + + it('emits only one AgentResultEvent for a resumed chain', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'first' }) + .addTurn({ type: 'textBlock', text: 'second' }) + + let invocationCount = 0 + const agent = new Agent({ model }) + agent.addHook(AfterInvocationEvent, (event: AfterInvocationEvent) => { + invocationCount++ + if (invocationCount === 1) { + event.resume = 'follow-up' + } + }) + + const items = await collectIterator(agent.stream('initial')) + + const agentResults = items.filter((e) => e instanceof AgentResultEvent) + expect(agentResults).toHaveLength(1) + const afterInvocations = items.filter((e) => e instanceof AfterInvocationEvent) + expect(afterInvocations).toHaveLength(2) + }) + }) + + describe('queue-based lifecycle plugin (WASM bridge pattern)', () => { + function createLifecycleBridgePlugin(queue: string[]): Plugin { + return { + name: 'strands:lifecycle-bridge', + initAgent(agent: LocalAgent): void { + agent.addHook(InitializedEvent, () => { + queue.push('initialized') + }) + agent.addHook(BeforeInvocationEvent, () => { + queue.push('before-invocation') + }) + agent.addHook(AfterInvocationEvent, () => { + queue.push('after-invocation') + }) + agent.addHook(BeforeModelCallEvent, () => { + queue.push('before-model-call') + }) + agent.addHook(AfterModelCallEvent, () => { + queue.push('after-model-call') + }) + agent.addHook(MessageAddedEvent, () => { + queue.push('message-added') + }) + agent.addHook(BeforeToolCallEvent, () => { + queue.push('before-tool-call') + }) + agent.addHook(AfterToolCallEvent, () => { + queue.push('after-tool-call') + }) + }, + } + } + + it('receives lifecycle events when registered via plugins config', async () => { + const queue: string[] = [] + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [createLifecycleBridgePlugin(queue)] }) + await agent.invoke('Hi') + + expect(queue).toStrictEqual([ + 'initialized', + 'before-invocation', + 'message-added', + 'before-model-call', + 'after-model-call', + 'message-added', + 'after-invocation', + ]) + }) + + it('receives no events when passed via non-existent hooks config field', async () => { + const queue: string[] = [] + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, hooks: [createLifecycleBridgePlugin(queue)] } as any) + await agent.invoke('Hi') + + expect(queue).toHaveLength(0) + }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent.interrupt.test.ts b/strands-ts/src/agent/__tests__/agent.interrupt.test.ts new file mode 100644 index 0000000000..75d0c54b6d --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent.interrupt.test.ts @@ -0,0 +1,951 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { ToolResultBlock } from '../../types/messages.js' +import { AfterToolCallEvent, BeforeToolCallEvent, BeforeToolsEvent, InterruptEvent } from '../../hooks/events.js' +import { FunctionTool } from '../../tools/function-tool.js' +import { InterruptResponseContent } from '../../types/interrupt.js' +import type { InterruptState, PendingToolExecution } from '../../interrupt.js' + +/** Access the agent's internal interrupt state for test assertions. */ +function getPendingToolExecution(agent: Agent): PendingToolExecution | undefined { + // yes it's dirty, but we don't want to expose this publicly + return (agent as unknown as { _interruptState: InterruptState })._interruptState.pendingToolExecution +} + +describe('Agent interrupt system', () => { + describe('interrupt from tool callback', () => { + it('returns stopReason interrupt when tool calls interrupt()', async () => { + // Model returns tool use first, then text block (following standard test pattern) + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Should not reach this' }) + + const tool = createMockTool('confirmTool', (context) => { + context.interrupt({ name: 'confirm', reason: 'Please confirm' }) + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + const result = await agent.invoke('Test') + + expect(result).toMatchObject({ + stopReason: 'interrupt', + interrupts: [{ name: 'confirm', reason: 'Please confirm' }], + }) + }) + }) + + describe('interrupt from BeforeToolCallEvent hook', () => { + it('returns stopReason interrupt when hook calls interrupt()', async () => { + // Model returns tool use first, then text block (following standard test pattern) + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'testTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Should not reach this' }) + + const tool = createMockTool('testTool', () => 'Success') + + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(BeforeToolCallEvent, (event) => { + if (event.toolUse.name === 'testTool') { + event.interrupt({ name: 'confirm_tool', reason: 'Confirm tool execution?' }) + } + }) + + const result = await agent.invoke('Test') + + expect(result).toMatchObject({ + stopReason: 'interrupt', + interrupts: [{ name: 'confirm_tool', reason: 'Confirm tool execution?' }], + }) + }) + + it('stores pending state and resumes correctly after interrupt', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'deleteTool', + toolUseId: 'tool-1', + input: { key: 'X' }, + }) + .addTurn({ type: 'textBlock', text: 'Deleted' }) + + let toolExecuted = false + const tool = createMockTool('deleteTool', () => { + toolExecuted = true + return 'deleted' + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(BeforeToolCallEvent, (event) => { + if (event.toolUse.name === 'deleteTool') { + const approval = event.interrupt({ name: 'approve_delete', reason: 'Confirm delete?' }) + if (approval !== 'yes') { + event.cancel = 'not approved' + } + } + }) + + // First invocation — hook interrupts before tool runs + const interruptResult = await agent.invoke('Delete X') + + expect(interruptResult.stopReason).toBe('interrupt') + expect(interruptResult.interrupts).toMatchObject([ + { id: expect.any(String), name: 'approve_delete', reason: 'Confirm delete?', source: 'hook' }, + ]) + expect(toolExecuted).toBe(false) + expect(model.callCount).toBe(1) + + // Verify pending execution state was stored (the core of pgrayy's concern: + // the InterruptError thrown back into the generator at `yield beforeToolCallEvent` + // must propagate to executeTools' catch block which stores this state) + const pendingExecution = getPendingToolExecution(agent) + expect(pendingExecution).toEqual({ + assistantMessageData: { + role: 'assistant', + content: [{ toolUse: { name: 'deleteTool', toolUseId: 'tool-1', input: { key: 'X' } } }], + }, + completedToolResults: {}, + }) + + // Resume with approval — tool should now execute + const finalResult = await agent.invoke([ + new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'yes', + }), + ]) + + expect(finalResult.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(true) + expect(model.callCount).toBe(2) + }) + + it('preserves completed tool results when interrupt fires on a later tool', async () => { + // Tools A, B, C — hook interrupts on B's BeforeToolCallEvent + // A should complete, B and C should not execute + // On resume, A is skipped, B and C execute + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'toolA', toolUseId: 'tool-a', input: {} }, + { type: 'toolUseBlock', name: 'toolB', toolUseId: 'tool-b', input: {} }, + { type: 'toolUseBlock', name: 'toolC', toolUseId: 'tool-c', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'All done' }) + + const executionLog: string[] = [] + + const toolA = createMockTool('toolA', () => { + executionLog.push('A') + return 'A result' + }) + const toolB = createMockTool('toolB', () => { + executionLog.push('B') + return 'B result' + }) + const toolC = createMockTool('toolC', () => { + executionLog.push('C') + return 'C result' + }) + + const agent = new Agent({ model, tools: [toolA, toolB, toolC], toolExecutor: 'sequential', printer: false }) + + agent.addHook(BeforeToolCallEvent, (event) => { + if (event.toolUse.name === 'toolB') { + event.interrupt({ name: 'approve_b', reason: 'Approve B?' }) + } + }) + + const interruptResult = await agent.invoke('Run all') + + expect(interruptResult.stopReason).toBe('interrupt') + expect(executionLog).toEqual(['A']) + + // Verify pending state includes A's completed result + const pendingExecution = getPendingToolExecution(agent) + expect(Object.keys(pendingExecution!.completedToolResults)).toEqual(['tool-a']) + expect(pendingExecution!.completedToolResults['tool-a']!.toolResult.toolUseId).toBe('tool-a') + + // Resume — A should be skipped, B and C should execute + const finalResult = await agent.invoke([ + new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'approved', + }), + ]) + + expect(finalResult.stopReason).toBe('endTurn') + expect(executionLog).toEqual(['A', 'B', 'C']) + expect(model.callCount).toBe(2) + }) + }) + + describe('interrupt from BeforeToolsEvent hook', () => { + it('returns stopReason interrupt when hook calls interrupt()', async () => { + // Model returns tool use first, then text block (following standard test pattern) + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'testTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Should not reach this' }) + + const tool = createMockTool('testTool', () => 'Success') + + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(BeforeToolsEvent, (event) => { + event.interrupt({ name: 'batch_approval', reason: 'Approve all tools?' }) + }) + + const result = await agent.invoke('Test') + + expect(result).toMatchObject({ + stopReason: 'interrupt', + interrupts: [{ name: 'batch_approval', reason: 'Approve all tools?' }], + }) + }) + }) + + describe('resume flow - interrupt → response → continue', () => { + it('resumes tool callback execution without re-calling model', async () => { + // Turn 0: Model returns tool use (will be interrupted) + // Turn 1: Model returns final response (after tool completes on resume) + // Note: Resume skips model call and uses stored message + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-1', + input: { amount: 5000 }, + }) + .addTurn({ type: 'textBlock', text: 'Transfer completed' }) + + let callCount = 0 + let receivedResponse: unknown + const tool = new FunctionTool({ + name: 'confirmTool', + description: 'Tool that requires confirmation', + inputSchema: { + type: 'object', + properties: { amount: { type: 'number' } }, + }, + callback: (rawInput, context) => { + callCount++ + const input = rawInput as { amount: number } + const response = context.interrupt({ + name: 'confirm_transfer', + reason: `Confirm transfer of $${input.amount}?`, + }) + receivedResponse = response + return (response as { approved: boolean })?.approved ? 'Transfer approved' : 'Transfer denied' + }, + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + // First invocation - triggers interrupt + const interruptResult = await agent.invoke('Transfer $5000') + + expect(interruptResult).toMatchObject({ + stopReason: 'interrupt', + interrupts: [{ name: 'confirm_transfer', reason: 'Confirm transfer of $5000?' }], + }) + expect(callCount).toBe(1) // Tool was called once before interrupt + expect(model.callCount).toBe(1) // Model was called once + + // Resume with user response + const finalResult = await agent.invoke([ + new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: { approved: true }, + }), + ]) + + expect(finalResult.stopReason).toBe('endTurn') + expect(receivedResponse).toEqual({ approved: true }) + expect(callCount).toBe(2) + expect(model.callCount).toBe(2) + + // Verify tool result was added to messages + const toolResultMessage = agent.messages.find( + (m) => m.role === 'user' && m.content.some((b) => b.type === 'toolResultBlock') + ) + expect(toolResultMessage).toBeDefined() + const toolResult = toolResultMessage?.content.find((b) => b.type === 'toolResultBlock') as + | ToolResultBlock + | undefined + expect(toolResult?.content[0]).toMatchObject({ type: 'textBlock', text: 'Transfer approved' }) + }) + + it('skips already-completed tools when resuming from partial execution', async () => { + // Scenario: Tools A, B, C where A & B succeed but C interrupts + // On resume: A & B should NOT re-execute, only C should execute + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'toolA', toolUseId: 'tool-a', input: {} }, + { type: 'toolUseBlock', name: 'toolB', toolUseId: 'tool-b', input: {} }, + { type: 'toolUseBlock', name: 'toolC', toolUseId: 'tool-c', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'All tools completed' }) + + const executionLog: string[] = [] + + const toolA = createMockTool('toolA', () => { + executionLog.push('A') + return 'A result' + }) + + const toolB = createMockTool('toolB', () => { + executionLog.push('B') + return 'B result' + }) + + const toolC = createMockTool('toolC', (context) => { + const response = context.interrupt({ + name: 'confirm_c', + reason: 'Confirm tool C?', + }) + executionLog.push('C') + return (response as { approved: boolean })?.approved ? 'C approved' : 'C denied' + }) + + const agent = new Agent({ model, tools: [toolA, toolB, toolC], printer: false }) + + // First invocation - A & B execute, C interrupts + const interruptResult = await agent.invoke('Run all tools') + + expect(interruptResult).toMatchObject({ + stopReason: 'interrupt', + interrupts: [{ name: 'confirm_c', reason: 'Confirm tool C?' }], + }) + expect(executionLog).toEqual(['A', 'B']) + expect(model.callCount).toBe(1) + + // Resume with response for C + const finalResult = await agent.invoke([ + new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: { approved: true }, + }), + ]) + + expect(finalResult.stopReason).toBe('endTurn') + expect(executionLog).toEqual(['A', 'B', 'C']) + expect(model.callCount).toBe(2) + + // Verify all tool results are present in messages + const toolResultMessage = agent.messages.find( + (m) => m.role === 'user' && m.content.filter((b) => b.type === 'toolResultBlock').length === 3 + ) + expect(toolResultMessage).toBeDefined() + }) + + it('throws TypeError when sending a new message while in interrupted state', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Different response' }) + + const tool = createMockTool('confirmTool', (context) => { + context.interrupt({ name: 'confirm', reason: 'Confirm?' }) + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + // First invocation - triggers interrupt + const interruptResult = await agent.invoke('First message') + expect(interruptResult).toMatchObject({ stopReason: 'interrupt' }) + + // Sending a new message instead of interrupt responses should throw + await expect(agent.invoke('Different question')).rejects.toThrow(TypeError) + await expect(agent.invoke('Different question')).rejects.toThrow('Agent is in an interrupted state') + }) + }) + + describe('error handling', () => { + it('throws error when interrupt() called on event with non-Agent implementation', async () => { + const mockLocalAgent = { id: 'mock' } as unknown as Agent + const event = new BeforeToolCallEvent({ + agent: mockLocalAgent, + toolUse: { name: 'test', toolUseId: 'id', input: {} }, + tool: undefined, + invocationState: {}, + }) + + expect(() => { + event.interrupt({ name: 'test', reason: 'test' }) + }).toThrow('Interrupt state not available') + }) + + it('throws TypeError when interrupt responses are mixed with other content blocks', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('confirmTool', (context) => { + context.interrupt({ name: 'confirm', reason: 'Confirm?' }) + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + // First invocation - triggers interrupt + const interruptResult = await agent.invoke('Test') + expect(interruptResult.stopReason).toBe('interrupt') + + // Resume with mixed content: interrupt response + text block + await expect( + agent.invoke([ + new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'yes', + }), + { type: 'textBlock', text: 'extra text' }, + ] as any) + ).rejects.toThrow(TypeError) + + await expect( + agent.invoke([ + new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'yes', + }), + { type: 'textBlock', text: 'extra text' }, + ] as any) + ).rejects.toThrow('Must resume from interrupt with a list of interruptResponse content blocks only') + }) + + it('allows pure interrupt response arrays without error', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('confirmTool', (context) => { + const response = context.interrupt({ name: 'confirm', reason: 'Confirm?' }) + return `Got: ${response}` + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + const interruptResult = await agent.invoke('Test') + expect(interruptResult.stopReason).toBe('interrupt') + + // Resume with pure interrupt responses — should succeed + const finalResult = await agent.invoke([ + new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'approved', + }), + ]) + + expect(finalResult.stopReason).toBe('endTurn') + }) + }) + + describe('multiple hook interrupts', () => { + it('collects interrupts from multiple BeforeToolCallEvent hooks', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'testTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Should not reach this' }) + + const tool = createMockTool('testTool', () => 'Success') + + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(BeforeToolCallEvent, (event) => { + event.interrupt({ name: 'security_check', reason: 'Security review required' }) + }) + agent.addHook(BeforeToolCallEvent, (event) => { + event.interrupt({ name: 'budget_check', reason: 'Budget approval required' }) + }) + + const result = await agent.invoke('Test') + + expect(result).toMatchObject({ + stopReason: 'interrupt', + interrupts: expect.arrayContaining([ + expect.objectContaining({ name: 'security_check', reason: 'Security review required' }), + expect.objectContaining({ name: 'budget_check', reason: 'Budget approval required' }), + ]), + }) + }) + + it('collects interrupts from multiple BeforeToolsEvent hooks', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'testTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Should not reach this' }) + + const tool = createMockTool('testTool', () => 'Success') + + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(BeforeToolsEvent, (event) => { + event.interrupt({ name: 'approval_a', reason: 'First approval' }) + }) + agent.addHook(BeforeToolsEvent, (event) => { + event.interrupt({ name: 'approval_b', reason: 'Second approval' }) + }) + + const result = await agent.invoke('Test') + + expect(result).toMatchObject({ + stopReason: 'interrupt', + interrupts: expect.arrayContaining([ + expect.objectContaining({ name: 'approval_a', reason: 'First approval' }), + expect.objectContaining({ name: 'approval_b', reason: 'Second approval' }), + ]), + }) + }) + + it('resumes correctly after multiple interrupts are answered', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'testTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'All approved' }) + + let securityResponse: unknown + let budgetResponse: unknown + let hookCallCount = 0 + + const tool = createMockTool('testTool', () => 'Success') + + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(BeforeToolCallEvent, (event) => { + hookCallCount++ + securityResponse = event.interrupt({ name: 'security_check', reason: 'Security review' }) + }) + agent.addHook(BeforeToolCallEvent, (event) => { + hookCallCount++ + budgetResponse = event.interrupt({ name: 'budget_check', reason: 'Budget review' }) + }) + + // First invocation — both hooks interrupt + const interruptResult = await agent.invoke('Test') + expect(interruptResult).toMatchObject({ + stopReason: 'interrupt', + interrupts: expect.arrayContaining([ + expect.objectContaining({ name: 'security_check' }), + expect.objectContaining({ name: 'budget_check' }), + ]), + }) + expect(interruptResult.interrupts).toHaveLength(2) + expect(hookCallCount).toBe(2) + expect(model.callCount).toBe(1) + + // Resume with responses for both interrupts + const finalResult = await agent.invoke( + interruptResult.interrupts!.map( + (interrupt) => + new InterruptResponseContent({ + interruptId: interrupt.id, + response: `approved:${interrupt.name}`, + }) + ) + ) + + expect(finalResult.stopReason).toBe('endTurn') + expect(model.callCount).toBe(2) + expect(securityResponse).toBe('approved:security_check') + expect(budgetResponse).toBe('approved:budget_check') + }) + }) + + describe('multi-cycle interrupts', () => { + it('interrupts again on cycle 2 after resuming from cycle 1 (BeforeToolsEvent)', async () => { + // Cycle 1: model returns tool use → hook interrupts → user resumes → tool executes + // Cycle 2: model returns another tool use → same hook should interrupt again + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-2', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('testTool', () => 'ok') + + let interruptCount = 0 + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(BeforeToolsEvent, (event) => { + interruptCount++ + event.interrupt({ name: 'approval', reason: 'Approve?' }) + }) + + // Cycle 1: interrupt + const result1 = await agent.invoke('Go') + expect(result1).toMatchObject({ + stopReason: 'interrupt', + interrupts: [{ name: 'approval', reason: 'Approve?' }], + }) + expect(interruptCount).toBe(1) + + // Resume cycle 1 + const result2 = await agent.invoke( + result1.interrupts!.map( + (i) => + new InterruptResponseContent({ + interruptId: i.id, + response: 'yes', + }) + ) + ) + + // Cycle 2: should interrupt again, not silently pass through + expect(result2).toMatchObject({ stopReason: 'interrupt' }) + expect(interruptCount).toBe(3) + }) + }) + + describe('event contract during interrupt', () => { + it('does not fire AfterToolCallEvent when BeforeToolCallEvent interrupt triggers', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'testTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('testTool', () => 'Success') + const agent = new Agent({ model, tools: [tool], printer: false }) + + const firedEvents: string[] = [] + + agent.addHook(BeforeToolCallEvent, (event) => { + firedEvents.push('BeforeToolCallEvent') + event.interrupt({ name: 'confirm', reason: 'Confirm?' }) + }) + agent.addHook(AfterToolCallEvent, () => { + firedEvents.push('AfterToolCallEvent') + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('interrupt') + expect(firedEvents).toContain('BeforeToolCallEvent') + expect(firedEvents).not.toContain('AfterToolCallEvent') + }) + + it('does not fire AfterToolCallEvent when tool callback interrupts', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-1', + input: {}, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('confirmTool', (context) => { + context.interrupt({ name: 'confirm', reason: 'Confirm?' }) + }) + const agent = new Agent({ model, tools: [tool], printer: false }) + + const firedEvents: string[] = [] + + agent.addHook(BeforeToolCallEvent, () => { + firedEvents.push('BeforeToolCallEvent') + }) + agent.addHook(AfterToolCallEvent, () => { + firedEvents.push('AfterToolCallEvent') + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('interrupt') + expect(firedEvents).toContain('BeforeToolCallEvent') + expect(firedEvents).not.toContain('AfterToolCallEvent') + }) + }) + + describe('concurrent tool execution with interrupts', () => { + it('allows in-flight tool to complete when sibling interrupts', async () => { + // Use gated tools to prove concurrency: A completes AFTER B interrupts, + // demonstrating that the executor waits for in-flight tools. + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'toolA', toolUseId: 'tool-a', input: {} }, + { type: 'toolUseBlock', name: 'toolB', toolUseId: 'tool-b', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolACompleted = false + let toolAResolve: () => void + const toolAGate = new Promise((resolve) => (toolAResolve = resolve)) + let toolAStartedResolve: () => void + const toolAStarted = new Promise((resolve) => (toolAStartedResolve = resolve)) + + const toolA = new FunctionTool({ + name: 'toolA', + description: 'Gated tool A', + inputSchema: { type: 'object', properties: {} }, + callback: async () => { + toolAStartedResolve() + await toolAGate + toolACompleted = true + return 'A done' + }, + }) + + const toolB = new FunctionTool({ + name: 'toolB', + description: 'Interrupting tool B', + inputSchema: { type: 'object', properties: {} }, + callback: (_input, context) => { + // Interrupt immediately — A is still in-flight + context!.interrupt({ name: 'confirm_b', reason: 'Approve B?' }) + return 'B done' + }, + }) + + const agent = new Agent({ + model, + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + const invocation = agent.invoke('Go') + + // Wait for A to start (proves both tools launched concurrently) + await toolAStarted + + // B has already interrupted, but A is still in-flight + expect(toolACompleted).toBe(false) + + // Release A — executor should let it finish + toolAResolve!() + const result = await invocation + + expect(result.stopReason).toBe('interrupt') + expect(toolACompleted).toBe(true) + expect(result.interrupts).toMatchObject([ + { id: expect.any(String), name: 'confirm_b', reason: 'Approve B?', source: 'tool' }, + ]) + + // Verify A's result was captured in pending state + const pendingExecution = getPendingToolExecution(agent) + expect(pendingExecution!.completedToolResults['tool-a']).toEqual({ + toolResult: { toolUseId: 'tool-a', status: 'success', content: [{ text: 'A done' }] }, + }) + }) + + it('stores completed tool results and resumes only the interrupted tool', async () => { + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'toolA', toolUseId: 'tool-a', input: {} }, + { type: 'toolUseBlock', name: 'toolB', toolUseId: 'tool-b', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolAResolve: () => void + const toolAGate = new Promise((resolve) => (toolAResolve = resolve)) + const executionLog: string[] = [] + + const toolA = new FunctionTool({ + name: 'toolA', + description: 'Gated tool A', + inputSchema: { type: 'object', properties: {} }, + callback: async () => { + executionLog.push('A') + await toolAGate + return 'A result' + }, + }) + + const toolB = new FunctionTool({ + name: 'toolB', + description: 'Interrupting tool B', + inputSchema: { type: 'object', properties: {} }, + callback: (_input, context) => { + executionLog.push('B') + const response = context!.interrupt({ name: 'confirm_b', reason: 'Approve?' }) + return `B: ${response}` + }, + }) + + const agent = new Agent({ + model, + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + // Release A immediately so it completes + toolAResolve!() + const interruptResult = await agent.invoke('Go') + + expect(interruptResult.stopReason).toBe('interrupt') + expect(executionLog).toEqual(['A', 'B']) + + // Verify pending state has A's result + const pendingExecution = getPendingToolExecution(agent) + expect(Object.keys(pendingExecution!.completedToolResults)).toEqual(['tool-a']) + + // Resume — only B should re-execute + executionLog.length = 0 + const finalResult = await agent.invoke([ + { + interruptResponse: { + interruptId: interruptResult.interrupts![0]!.id, + response: 'approved', + }, + }, + ]) + + expect(finalResult.stopReason).toBe('endTurn') + expect(executionLog).toEqual(['B']) + }) + + it('handles BeforeToolCallEvent interrupt in concurrent mode', async () => { + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'toolA', toolUseId: 'tool-a', input: {} }, + { type: 'toolUseBlock', name: 'toolB', toolUseId: 'tool-b', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const executionLog: string[] = [] + + const toolA = new FunctionTool({ + name: 'toolA', + description: 'Tool A', + inputSchema: { type: 'object', properties: {} }, + callback: async () => { + executionLog.push('A') + return 'A result' + }, + }) + const toolB = new FunctionTool({ + name: 'toolB', + description: 'Tool B', + inputSchema: { type: 'object', properties: {} }, + callback: async () => { + executionLog.push('B') + return 'B result' + }, + }) + + const agent = new Agent({ + model, + tools: [toolA, toolB], + toolExecutor: 'concurrent', + printer: false, + }) + + agent.addHook(BeforeToolCallEvent, (event) => { + if (event.toolUse.name === 'toolB') { + event.interrupt({ name: 'approve_b', reason: 'Approve B?' }) + } + }) + + const interruptResult = await agent.invoke('Go') + + expect(interruptResult.stopReason).toBe('interrupt') + expect(interruptResult.interrupts).toMatchObject([ + { id: expect.any(String), name: 'approve_b', reason: 'Approve B?', source: 'hook' }, + ]) + // A should have executed, B should not (interrupted before execution) + expect(executionLog).toContain('A') + expect(executionLog).not.toContain('B') + }) + }) + + describe('InterruptEvent emission', () => { + it('yields one InterruptEvent per unanswered interrupt at stop, tagged with source', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'toolA', toolUseId: 'tool-a', input: {} }) + .addTurn({ type: 'textBlock', text: 'done' }) + + const toolA = createMockTool('toolA', (context) => { + context.interrupt({ name: 'confirm_tool', reason: 'ok?' }) + }) + + const agent = new Agent({ model, tools: [toolA], printer: false }) + + // Hook-raised interrupt on a different identifier, via BeforeToolCallEvent. + agent.addHook(BeforeToolCallEvent, (event) => { + if (event.toolUse.name === 'toolA') { + event.interrupt({ name: 'confirm_hook', reason: 'hook ok?' }) + } + }) + + const emittedEvents: InterruptEvent[] = [] + agent.addHook(InterruptEvent, (event) => { + emittedEvents.push(event) + }) + + const result = await agent.invoke('go') + + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts).toHaveLength(emittedEvents.length) + // One event per interrupt, each tagged by its origin. Hook interrupts fire + // before tool callbacks, so the hook interrupt is the only one in this run. + for (const event of emittedEvents) { + expect(event.interrupt.source).toBe('hook') + } + }) + + it('InterruptEvent is available on the stream', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'approveMe', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'done' }) + + const tool = createMockTool('approveMe', (context) => { + context.interrupt({ name: 'approve', reason: 'please' }) + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + const events: InterruptEvent[] = [] + for await (const event of agent.stream('go')) { + if (event instanceof InterruptEvent) events.push(event) + } + + expect(events).toHaveLength(1) + expect(events[0]!.interrupt).toMatchObject({ name: 'approve', source: 'tool' }) + }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent.invocation-state.test.ts b/strands-ts/src/agent/__tests__/agent.invocation-state.test.ts new file mode 100644 index 0000000000..5f38205fe4 --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent.invocation-state.test.ts @@ -0,0 +1,276 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../agent.js' +import { + AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + MessageAddedEvent, +} from '../../hooks/events.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { ToolResultBlock, TextBlock } from '../../types/messages.js' +import type { InvocationState } from '../../types/agent.js' + +describe('invocationState', () => { + describe('round-trip', () => { + it('returns an empty object on AgentResult when no invocationState is passed', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + const result = await agent.invoke('Hi') + + expect(result.invocationState).toEqual({}) + }) + + it('returns the passed invocationState on AgentResult', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + const result = await agent.invoke('Hi', { invocationState: { userId: 'u-1', traceId: 't-1' } }) + + expect(result.invocationState).toEqual({ userId: 'u-1', traceId: 't-1' }) + }) + + it('preserves reference identity: caller keeps the same object they passed in', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + const state: InvocationState = { userId: 'u-1' } + const result = await agent.invoke('Hi', { invocationState: state }) + + expect(result.invocationState).toBe(state) + }) + }) + + describe('hook mutation', () => { + it('propagates mutations from BeforeModelCallEvent to AfterModelCallEvent and AgentResult', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + let seenInAfter: InvocationState | undefined + agent.addHook(BeforeModelCallEvent, (event) => { + event.invocationState.counter = (event.invocationState.counter as number | undefined) ?? 0 + event.invocationState.counter = (event.invocationState.counter as number) + 1 + }) + agent.addHook(AfterModelCallEvent, (event) => { + seenInAfter = event.invocationState + }) + + const result = await agent.invoke('Hi') + + expect(seenInAfter).toEqual({ counter: 1 }) + expect(result.invocationState).toEqual({ counter: 1 }) + }) + + it('shares the same invocationState object across all lifecycle events in one invocation', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + const seen: InvocationState[] = [] + const collect = (event: { invocationState: InvocationState }): void => { + seen.push(event.invocationState) + } + + agent.addHook(BeforeInvocationEvent, collect) + agent.addHook(BeforeModelCallEvent, collect) + agent.addHook(AfterModelCallEvent, collect) + agent.addHook(MessageAddedEvent, collect) + agent.addHook(AfterInvocationEvent, collect) + + const result = await agent.invoke('Hi') + + // Every hook, plus the result, sees the same reference. + expect(seen.length).toBeGreaterThan(0) + for (const observed of seen) { + expect(observed).toBe(result.invocationState) + } + }) + }) + + describe('multi-cycle persistence', () => { + it('persists mutations across recursive agent loop cycles (tool-use scenario)', async () => { + const tool = createMockTool( + 'ping', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('pong')], + }) + ) + + const model = new MockMessageModel() + .addTurn([{ type: 'toolUseBlock', name: 'ping', toolUseId: 'tool-1', input: {} }]) + .addTurn({ type: 'textBlock', text: 'Done' }) + const agent = new Agent({ model, tools: [tool] }) + + // Write in AfterToolCallEvent during cycle 1; read in BeforeModelCallEvent during cycle 2. + let cycle2State: InvocationState | undefined + let modelCalls = 0 + agent.addHook(AfterToolCallEvent, (event) => { + event.invocationState.toolCompleted = true + }) + agent.addHook(BeforeModelCallEvent, (event) => { + modelCalls++ + if (modelCalls === 2) { + cycle2State = event.invocationState + } + }) + + const result = await agent.invoke('Run ping') + + expect(modelCalls).toBe(2) + expect(cycle2State).toEqual({ toolCompleted: true }) + expect(result.invocationState).toEqual({ toolCompleted: true }) + }) + }) + + describe('tool access', () => { + it('passes invocationState to tools via ToolContext and surfaces mutations on the result', async () => { + const tool = createMockTool('writer', () => { + throw new Error('unused') + }) + // Override stream to read/write invocationState. + // eslint-disable-next-line require-yield + tool.stream = async function* (context) { + const prev = (context.invocationState.callCount as number | undefined) ?? 0 + context.invocationState.callCount = prev + 1 + context.invocationState.lastToolSeenUserId = context.invocationState.userId + return new ToolResultBlock({ + toolUseId: context.toolUse.toolUseId, + status: 'success', + content: [new TextBlock('ok')], + }) + } + + const model = new MockMessageModel() + .addTurn([{ type: 'toolUseBlock', name: 'writer', toolUseId: 'tu-1', input: {} }]) + .addTurn({ type: 'textBlock', text: 'Done' }) + const agent = new Agent({ model, tools: [tool] }) + + const result = await agent.invoke('Run writer', { invocationState: { userId: 'u-42' } }) + + expect(result.invocationState).toEqual({ + userId: 'u-42', + callCount: 1, + lastToolSeenUserId: 'u-42', + }) + }) + }) + + describe('isolation from appState', () => { + it('does not touch agent.appState when invocationState is mutated', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, appState: { persistent: 'yes' } }) + + agent.addHook(BeforeModelCallEvent, (event) => { + event.invocationState.ephemeral = 'only-this-run' + }) + + const result = await agent.invoke('Hi', { invocationState: { requestId: 'r-1' } }) + + expect(result.invocationState).toEqual({ requestId: 'r-1', ephemeral: 'only-this-run' }) + expect(agent.appState.get('persistent')).toBe('yes') + expect(agent.appState.get('ephemeral')).toBeUndefined() + expect(agent.appState.get('requestId')).toBeUndefined() + }) + }) + + describe('across invocations', () => { + it('does not leak state between invocations on the same agent (default bag)', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'A' }) + .addTurn({ type: 'textBlock', text: 'B' }) + const agent = new Agent({ model }) + + agent.addHook(BeforeModelCallEvent, (event) => { + event.invocationState.seen = true + }) + + const first = await agent.invoke('1') + const second = await agent.invoke('2') + + expect(first.invocationState).toEqual({ seen: true }) + expect(second.invocationState).toEqual({ seen: true }) + expect(first.invocationState).not.toBe(second.invocationState) + }) + }) + + describe('retry paths', () => { + it('preserves same invocationState reference across AfterModelCallEvent retry', async () => { + const model = new MockMessageModel() + .addTurn(new Error('transient failure')) + .addTurn({ type: 'textBlock', text: 'Success after retry' }) + const agent = new Agent({ model, printer: false }) + + let retried = false + const seen: InvocationState[] = [] + + agent.addHook(BeforeModelCallEvent, (event) => { + seen.push(event.invocationState) + event.invocationState.modelCalls = (event.invocationState.modelCalls as number | undefined) ?? 0 + event.invocationState.modelCalls = (event.invocationState.modelCalls as number) + 1 + }) + agent.addHook(AfterModelCallEvent, (event) => { + seen.push(event.invocationState) + if (!retried && event.error) { + retried = true + event.retry = true + } + }) + + const result = await agent.invoke('Test', { invocationState: { userId: 'u-1' } }) + + // Retry path was exercised: two Before + two After observations. + expect(seen.length).toBe(4) + // Every observation is the same object the caller passed in. + for (const observed of seen) { + expect(observed).toBe(result.invocationState) + } + // Mutations from the first attempt survive into the retry. + expect(result.invocationState).toEqual({ userId: 'u-1', modelCalls: 2 }) + }) + + it('preserves same invocationState reference across AfterToolCallEvent retry', async () => { + let toolCalls = 0 + const tool = createMockTool('flaky', () => { + toolCalls++ + return new ToolResultBlock({ + toolUseId: 'tu-1', + status: toolCalls === 1 ? 'error' : 'success', + content: [new TextBlock(toolCalls === 1 ? 'fail' : 'ok')], + }) + }) + + const model = new MockMessageModel() + .addTurn([{ type: 'toolUseBlock', name: 'flaky', toolUseId: 'tu-1', input: {} }]) + .addTurn({ type: 'textBlock', text: 'Done' }) + const agent = new Agent({ model, tools: [tool], printer: false }) + + let retried = false + const seen: InvocationState[] = [] + + agent.addHook(AfterToolCallEvent, (event) => { + seen.push(event.invocationState) + event.invocationState.toolAttempts = (event.invocationState.toolAttempts as number | undefined) ?? 0 + event.invocationState.toolAttempts = (event.invocationState.toolAttempts as number) + 1 + if (!retried && event.result.status === 'error') { + retried = true + event.retry = true + } + }) + + const result = await agent.invoke('Run flaky', { invocationState: { requestId: 'r-1' } }) + + // Retry fired twice: failed attempt + successful attempt. + expect(toolCalls).toBe(2) + expect(seen.length).toBe(2) + for (const observed of seen) { + expect(observed).toBe(result.invocationState) + } + expect(result.invocationState).toEqual({ requestId: 'r-1', toolAttempts: 2 }) + }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent.model-retry.test.ts b/strands-ts/src/agent/__tests__/agent.model-retry.test.ts new file mode 100644 index 0000000000..5326d05bb6 --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent.model-retry.test.ts @@ -0,0 +1,199 @@ +// End-to-end wiring test for DefaultModelRetryStrategy on the Agent constructor. +// Uses fake timers so the retry backoff never waits real wall time. + +import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest' +import { Agent } from '../agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { DefaultModelRetryStrategy } from '../../retry/default-model-retry-strategy.js' +import { ModelRetryStrategy } from '../../retry/model-retry-strategy.js' +import type { RetryDecision } from '../../retry/retry-strategy.js' +import { ConstantBackoff } from '../../retry/backoff-strategy.js' +import { ModelThrottledError } from '../../errors.js' +import { AfterModelCallEvent } from '../../hooks/events.js' +import { logger } from '../../logging/logger.js' + +describe('Agent retryStrategy wiring', () => { + beforeEach(() => { + vi.useFakeTimers() + }) + afterEach(() => { + vi.useRealTimers() + }) + + it('retries model calls that throw ModelThrottledError', async () => { + const model = new MockMessageModel() + .addTurn(new ModelThrottledError('rate limited')) + .addTurn({ type: 'textBlock', text: 'ok' }) + + const agent = new Agent({ + model, + retryStrategy: new DefaultModelRetryStrategy({ + maxAttempts: 3, + backoff: new ConstantBackoff({ delayMs: 1 }), + }), + }) + + const invokePromise = agent.invoke('hi') + // Flush any pending timers the retry scheduled. + await vi.runAllTimersAsync() + const result = await invokePromise + + expect(result.lastMessage.content[0]).toEqual({ type: 'textBlock', text: 'ok' }) + }) + + it('does not retry non-throttling errors', async () => { + const model = new MockMessageModel().addTurn(new Error('boom')) + + const agent = new Agent({ + model, + retryStrategy: new DefaultModelRetryStrategy({ + maxAttempts: 3, + backoff: new ConstantBackoff({ delayMs: 1 }), + }), + }) + + const invokePromise = agent.invoke('hi') + const assertion = expect(invokePromise).rejects.toThrow('boom') + await vi.runAllTimersAsync() + await assertion + }) + + it('installs a default DefaultModelRetryStrategy when none is provided', async () => { + // With no override, two ModelThrottledErrors in a row should still succeed + // because the defaults allow multiple attempts. + const model = new MockMessageModel() + .addTurn(new ModelThrottledError('throttled 1')) + .addTurn(new ModelThrottledError('throttled 2')) + .addTurn({ type: 'textBlock', text: 'ok' }) + + const agent = new Agent({ model }) + const invokePromise = agent.invoke('hi') + await vi.runAllTimersAsync() + const result = await invokePromise + + expect(result.lastMessage.content[0]).toEqual({ type: 'textBlock', text: 'ok' }) + }) + + it('gives up once maxAttempts is exceeded', async () => { + const model = new MockMessageModel() + .addTurn(new ModelThrottledError('throttled 1')) + .addTurn(new ModelThrottledError('throttled 2')) + .addTurn(new ModelThrottledError('throttled 3')) + + const agent = new Agent({ + model, + retryStrategy: new DefaultModelRetryStrategy({ + maxAttempts: 2, + backoff: new ConstantBackoff({ delayMs: 1 }), + }), + }) + + const invokePromise = agent.invoke('hi') + const assertion = expect(invokePromise).rejects.toThrow(ModelThrottledError) + await vi.runAllTimersAsync() + await assertion + }) + + it('disables retries when retryStrategy is null', async () => { + const model = new MockMessageModel().addTurn(new ModelThrottledError('throttled')) + + const agent = new Agent({ model, retryStrategy: null }) + + const invokePromise = agent.invoke('hi') + const assertion = expect(invokePromise).rejects.toThrow(ModelThrottledError) + await vi.runAllTimersAsync() + await assertion + }) + + it('disables retries when retryStrategy is an empty array', async () => { + const model = new MockMessageModel().addTurn(new ModelThrottledError('throttled')) + + const agent = new Agent({ model, retryStrategy: [] }) + + const invokePromise = agent.invoke('hi') + const assertion = expect(invokePromise).rejects.toThrow(ModelThrottledError) + await vi.runAllTimersAsync() + await assertion + }) + + it('accepts an array of distinct retry strategy types', async () => { + // A trivial secondary strategy subclass so the two entries have different + // constructors (the default DefaultModelRetryStrategy cannot be paired + // with a second instance of itself — see the fail-fast test below). + class NoopRetryStrategy extends ModelRetryStrategy { + readonly name = 'test:noop-retry-strategy' + protected override computeRetryDecision(): RetryDecision { + return { retry: false } + } + } + + const model = new MockMessageModel() + .addTurn(new ModelThrottledError('throttled')) + .addTurn({ type: 'textBlock', text: 'ok' }) + + const primary = new DefaultModelRetryStrategy({ + maxAttempts: 3, + backoff: new ConstantBackoff({ delayMs: 1 }), + }) + + const agent = new Agent({ model, retryStrategy: [primary, new NoopRetryStrategy()] }) + const invokePromise = agent.invoke('hi') + await vi.runAllTimersAsync() + const result = await invokePromise + + expect(result.lastMessage.content[0]).toEqual({ type: 'textBlock', text: 'ok' }) + }) + + it('warns when two retry strategies of the same type are provided', () => { + const warn = vi.spyOn(logger, 'warn').mockImplementation(() => {}) + + new Agent({ + model: new MockMessageModel(), + retryStrategy: [new DefaultModelRetryStrategy(), new DefaultModelRetryStrategy()], + }) + + expect(warn).toHaveBeenCalledWith(expect.stringContaining('DefaultModelRetryStrategy')) + + warn.mockRestore() + }) + + it('respects a user hook that already set retry=true (no double wait, no double increment)', async () => { + const model = new MockMessageModel() + .addTurn(new ModelThrottledError('throttled')) + .addTurn({ type: 'textBlock', text: 'ok' }) + + const strategy = new DefaultModelRetryStrategy({ + maxAttempts: 2, // only 1 retry allowed — if our strategy also incremented, we'd exceed + backoff: new ConstantBackoff({ delayMs: 10_000 }), // huge delay — if we slept on top, test would time out + }) + + const agent = new Agent({ model, retryStrategy: strategy }) + agent.addHook(AfterModelCallEvent, (event) => { + if (event.error instanceof ModelThrottledError) { + event.retry = true + } + }) + + const invokePromise = agent.invoke('hi') + await vi.runAllTimersAsync() + const result = await invokePromise + + expect(result.lastMessage.content[0]).toEqual({ type: 'textBlock', text: 'ok' }) + }) + + it('throws if the same instance is attached to two agents', async () => { + const strategy = new DefaultModelRetryStrategy() + + const agent1 = new Agent({ + model: new MockMessageModel().addTurn({ type: 'textBlock', text: 'ok' }), + retryStrategy: strategy, + }) + await agent1.invoke('hi') + + const agent2 = new Agent({ + model: new MockMessageModel().addTurn({ type: 'textBlock', text: 'ok' }), + retryStrategy: strategy, + }) + await expect(agent2.invoke('hi')).rejects.toThrow(/already attached to another agent/) + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent.stateful-model.test.ts b/strands-ts/src/agent/__tests__/agent.stateful-model.test.ts new file mode 100644 index 0000000000..875f052a6e --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent.stateful-model.test.ts @@ -0,0 +1,197 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { MockSnapshotStorage } from '../../__fixtures__/mock-storage-provider.js' +import { SlidingWindowConversationManager } from '../../conversation-manager/sliding-window-conversation-manager.js' +import { NullConversationManager } from '../../conversation-manager/null-conversation-manager.js' +import { SessionManager } from '../../session/session-manager.js' +import { SNAPSHOT_SCHEMA_VERSION } from '../../types/snapshot.js' +import { Message } from '../../types/messages.js' +import type { StreamOptions } from '../../index.js' +import type { ModelStreamEvent } from '../../models/streaming.js' +import type { JSONValue } from '../../types/json.js' + +/** + * Mock model that advertises itself as stateful and records the modelState + * object it receives, so tests can verify the agent's modelState flows through. + */ +class StatefulMockModel extends MockMessageModel { + readonly receivedOptions: StreamOptions[] = [] + private readonly _responseIds: string[] + + constructor(responseIds: string[] = ['resp_1', 'resp_2', 'resp_3']) { + super() + this._responseIds = responseIds + } + + override get stateful(): boolean { + return true + } + + override async *stream(messages: Message[], options?: StreamOptions): AsyncGenerator { + this.receivedOptions.push(options ?? {}) + // Simulate that the provider captured a fresh response id on the wire. + if (options?.modelState) { + const next = this._responseIds[this.receivedOptions.length - 1] + if (next !== undefined) { + options.modelState.set('responseId', next) + } + } + yield* super.stream(messages, options) + } +} + +describe('Agent with stateful model', () => { + describe('constructor', () => { + it('throws when a conversationManager is supplied alongside a stateful model', () => { + const model = new StatefulMockModel() + expect( + () => new Agent({ model, conversationManager: new SlidingWindowConversationManager({ windowSize: 5 }) }) + ).toThrow(/stateful model/) + }) + + it('assigns NullConversationManager when the model is stateful', () => { + const model = new StatefulMockModel() + const agent = new Agent({ model, printer: false }) + // Private field; access through bracket notation to avoid making it public. + expect((agent as unknown as { _conversationManager: unknown })._conversationManager).toBeInstanceOf( + NullConversationManager + ) + }) + + it('initializes modelState as an empty store', () => { + const model = new StatefulMockModel() + const agent = new Agent({ model, printer: false }) + expect(agent.modelState.getAll()).toEqual({}) + }) + + it('hydrates modelState from AgentConfig.modelState', () => { + const model = new StatefulMockModel() + const agent = new Agent({ model, printer: false, modelState: { responseId: 'resp_restored' } }) + expect(agent.modelState.getAll()).toEqual({ responseId: 'resp_restored' }) + }) + }) + + describe('invocation', () => { + it('passes agent.modelState to the model via streamOptions.modelState', async () => { + const model = new StatefulMockModel(['resp_first']).addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, printer: false }) + await agent.invoke('Hello') + expect(model.receivedOptions[0]?.modelState).toBe(agent.modelState) + expect(agent.modelState.getAll()).toEqual({ responseId: 'resp_first' }) + }) + + it('clears messages after invocation since the server holds history', async () => { + const model = new StatefulMockModel().addTurn({ type: 'textBlock', text: 'Hi there' }) + const agent = new Agent({ model, printer: false }) + await agent.invoke('First turn') + expect(agent.messages).toEqual([]) + }) + + it('clears messages before SessionManager snapshots on AfterInvocationEvent', async () => { + // Guards the ordering of ModelPlugin vs SessionManager hooks on + // AfterInvocationEvent: ModelPlugin must clear messages *before* + // SessionManager persists the snapshot, otherwise the stored snapshot + // would duplicate history that the server already owns. + const storage = new MockSnapshotStorage() + const sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + const model = new StatefulMockModel().addTurn({ type: 'textBlock', text: 'reply' }) + const agent = new Agent({ id: 'agent-1', model, sessionManager, printer: false }) + + await agent.invoke('hi') + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'agent-1' }, + }) + expect(snapshot).not.toBeNull() + expect((snapshot!.data as { messages: unknown[] }).messages).toEqual([]) + }) + + it('preserves modelState across invocations so previous_response_id chains', async () => { + const model = new StatefulMockModel(['resp_1', 'resp_2']) + .addTurn({ type: 'textBlock', text: 'one' }) + .addTurn({ type: 'textBlock', text: 'two' }) + const agent = new Agent({ model, printer: false }) + + await agent.invoke('turn 1') + expect(agent.modelState.getAll()).toEqual({ responseId: 'resp_1' }) + + await agent.invoke('turn 2') + expect(agent.modelState.getAll()).toEqual({ responseId: 'resp_2' }) + + // Both turns should have seen the state at invocation time. + expect(model.receivedOptions).toHaveLength(2) + }) + }) + + describe('stateless model (default)', () => { + it('does not clear messages after invocation', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + await agent.invoke('Hi') + // user message + assistant reply + expect(agent.messages.length).toBe(2) + }) + + it('uses the caller-provided conversationManager', () => { + const model = new MockMessageModel() + const convo = new SlidingWindowConversationManager({ windowSize: 7 }) + const agent = new Agent({ model, conversationManager: convo }) + expect((agent as unknown as { _conversationManager: unknown })._conversationManager).toBe(convo) + }) + }) + + describe('SessionManager restore guard', () => { + // Pre-seeds a session snapshot with messages, then verifies that SessionManager + // discards those messages on restore when the model is stateful. + async function setupStorageWithMessages(agentId: string, sessionId: string): Promise { + const storage = new MockSnapshotStorage() + await storage.saveSnapshot({ + location: { sessionId, scope: 'agent', scopeId: agentId }, + snapshotId: 'latest', + isLatest: true, + snapshot: { + scope: 'agent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: new Date().toISOString(), + data: { + messages: [{ role: 'user', content: [{ text: 'old turn' }] }] as unknown as JSONValue, + state: {}, + systemPrompt: null, + modelState: {}, + }, + appData: {}, + }, + }) + return storage + } + + it('discards restored messages when the model is stateful', async () => { + const storage = await setupStorageWithMessages('agent-1', 'session-stateful') + const sessionManager = new SessionManager({ + sessionId: 'session-stateful', + storage: { snapshot: storage }, + }) + const model = new StatefulMockModel() + const agent = new Agent({ id: 'agent-1', model, sessionManager, printer: false }) + await agent.initialize() + expect(agent.messages).toEqual([]) + }) + + it('restores messages when the model is stateless', async () => { + const storage = await setupStorageWithMessages('agent-2', 'session-stateless') + const sessionManager = new SessionManager({ + sessionId: 'session-stateless', + storage: { snapshot: storage }, + }) + const model = new MockMessageModel() + const agent = new Agent({ id: 'agent-2', model, sessionManager, printer: false }) + await agent.initialize() + expect(agent.messages).toHaveLength(1) + expect(agent.messages[0]!.role).toBe('user') + }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent.test.ts b/strands-ts/src/agent/__tests__/agent.test.ts new file mode 100644 index 0000000000..abed0afad9 --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent.test.ts @@ -0,0 +1,1929 @@ +import { describe, expect, it, vi } from 'vitest' +import { z } from 'zod' +import { Agent, type ToolList } from '../agent.js' +import { McpClient } from '../../mcp.js' +import { McpTool } from '../../tools/mcp-tool.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import { createMockTool, createRandomTool } from '../../__fixtures__/tool-helpers.js' +import { ConcurrentInvocationError } from '../../errors.js' +import { + MaxTokensError, + TextBlock, + CachePointBlock, + Message, + ToolUseBlock, + ToolResultBlock, + ReasoningBlock, + GuardContentBlock, + ImageBlock, + VideoBlock, + DocumentBlock, +} from '../../index.js' +import { AgentPrinter } from '../printer.js' +import { + AfterInvocationEvent, + AfterToolCallEvent, + AfterToolsEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolsEvent, +} from '../../hooks/events.js' +import { BedrockModel } from '../../models/bedrock.js' +import { StructuredOutputError } from '../../errors.js' +import { expectLoopMetrics } from '../../__fixtures__/metrics-helpers.js' +import { expectAgentResult } from '../../__fixtures__/agent-helpers.js' + +describe('Agent', () => { + describe('stream', () => { + describe('basic streaming', () => { + it('returns AsyncGenerator', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + const result = agent.stream('Test prompt') + + expect(result).toBeDefined() + expect(typeof result[Symbol.asyncIterator]).toBe('function') + }) + + it('returns AsyncGenerator that can be iterated without type errors', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + // Ensures that the signature of agent.stream is correct + for await (const _ of agent.stream('Test prompt')) { + /* intentionally empty */ + } + }) + + it('yields AgentStreamEvent objects', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + const { items } = await collectGenerator(agent.stream('Test prompt')) + + expect(items.length).toBeGreaterThan(0) + const firstItem = items[0] + expect(firstItem).toEqual(new BeforeInvocationEvent({ agent: agent, invocationState: {} })) + }) + + it('returns AgentResult as generator return value', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + const { result } = await collectGenerator(agent.stream('Test prompt')) + + expect(result).toEqual( + expectAgentResult({ + stopReason: 'endTurn', + messageText: 'Hello', + cycleCount: 1, + traceCount: 1, + }) + ) + // Verify trace structure + expect(result.traces?.[0]?.children).toEqual( + expect.arrayContaining([expect.objectContaining({ name: 'stream_messages' })]) + ) + }) + }) + + describe('with tool use', () => { + it('handles tool execution flow', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Tool result processed' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('Tool executed')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + + const { items, result } = await collectGenerator(agent.stream('Use the tool')) + + // Check that tool-related events are yielded + const toolEvents = items.filter( + (event) => event.type === 'beforeToolsEvent' || event.type === 'afterToolsEvent' + ) + expect(toolEvents.length).toBeGreaterThan(0) + + // Check final result + expect(result.stopReason).toBe('endTurn') + }) + + it('yields tool-related events', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('Success')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + + const { items } = await collectGenerator(agent.stream('Test')) + + const beforeTools = items.find((e) => e.type === 'beforeToolsEvent') + const afterTools = items.find((e) => e.type === 'afterToolsEvent') + + expect(beforeTools).toEqual( + new BeforeToolsEvent({ + agent: agent, + message: new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'testTool', toolUseId: 'tool-1', input: {} })], + }), + invocationState: {}, + }) + ) + + expect(afterTools).toBeDefined() + expect(afterTools?.type).toBe('afterToolsEvent') + expect(afterTools?.message).toEqual({ + type: 'message', + role: 'user', + content: [ + { + type: 'toolResultBlock', + toolUseId: 'tool-1', + status: 'success', + content: [{ type: 'textBlock', text: 'Success' }], + }, + ], + }) + expect(afterTools).toHaveProperty('agent', agent) + }) + }) + + describe('error handling', () => { + it('throws MaxTokensError when model hits token limit', async () => { + const model = new MockMessageModel().addTurn( + { type: 'textBlock', text: 'Partial...' }, + { stopReason: 'maxTokens' } + ) + const agent = new Agent({ model }) + + await expect(async () => { + await collectGenerator(agent.stream('Test')) + }).rejects.toThrow(MaxTokensError) + }) + }) + + describe('hook error cleanup', () => { + it('fires AfterInvocationEvent when consumer breaks from stream and allows reinvocation', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('ok')], + }) + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + const afterInvocationCallback = vi.fn() + agent.addHook(AfterInvocationEvent, afterInvocationCallback) + + for await (const event of agent.stream('Test')) { + if (event.type === 'beforeToolsEvent') { + break + } + } + + expect(afterInvocationCallback).toHaveBeenCalledOnce() + + const result = await agent.invoke('Test again') + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content[0]).toEqual(new TextBlock('Done')) + }) + }) + }) + + describe('invoke', () => { + describe('basic invocation', () => { + it('returns Promise', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + const result = agent.invoke('Test prompt') + + expect(result).toBeInstanceOf(Promise) + const awaited = await result + expect(awaited).toHaveProperty('stopReason') + expect(awaited).toHaveProperty('lastMessage') + }) + + it('returns correct stopReason and lastMessage', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response text' }) + const agent = new Agent({ model }) + + const result = await agent.invoke('Test prompt') + + expect(result).toEqual( + expectAgentResult({ + stopReason: 'endTurn', + messageText: 'Response text', + cycleCount: 1, + traceCount: 1, + }) + ) + }) + + it('consumes stream events internally', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + const result = await agent.invoke('Test') + + expect(result).toEqual( + expect.objectContaining({ + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ + type: 'message', + role: 'assistant', + content: expect.arrayContaining([expect.objectContaining({ type: 'textBlock', text: 'Hello' })]), + }), + metrics: expectLoopMetrics({ cycleCount: 1 }), + }) + ) + }) + }) + + describe('with tool use', () => { + it('executes tools and returns final result', async () => { + const model = new MockMessageModel() + .addTurn( + { type: 'toolUseBlock', name: 'calc', toolUseId: 'tool-1', input: { a: 1, b: 2 } }, + { + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + } + ) + .addTurn( + { type: 'textBlock', text: 'The answer is 3' }, + { + usage: { inputTokens: 200, outputTokens: 30, totalTokens: 230 }, + } + ) + + const tool = createMockTool( + 'calc', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('3')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + + const result = await agent.invoke('What is 1 + 2?') + + expect(result).toEqual( + expectAgentResult({ + stopReason: 'endTurn', + messageText: 'The answer is 3', + cycleCount: 2, + toolNames: ['calc'], + traceCount: 2, + usage: { inputTokens: 300, outputTokens: 80, totalTokens: 380 }, + }) + ) + // Verify detailed trace children structure + expect(result.traces?.[0]?.children).toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'stream_messages' }), + expect.objectContaining({ name: 'Tool: calc' }), + ]) + ) + expect(result.traces?.[1]?.children).toEqual( + expect.arrayContaining([expect.objectContaining({ name: 'stream_messages' })]) + ) + }) + + it('stores cycleId in trace metadata', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calc', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'calc', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('result')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + + const result = await agent.invoke('Test') + + expect(result.traces).toEqual([ + expect.objectContaining({ + name: 'Cycle 1', + metadata: expect.objectContaining({ cycleId: 'cycle-1' }), + }), + expect.objectContaining({ + name: 'Cycle 2', + metadata: expect.objectContaining({ cycleId: 'cycle-2' }), + }), + ]) + }) + + it('stores tool metadata in trace children', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-abc123', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-abc123', + status: 'success' as const, + content: [new TextBlock('result')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + + const result = await agent.invoke('Test') + + expect(result.traces).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + name: 'Cycle 1', + children: expect.arrayContaining([ + expect.objectContaining({ + name: 'Tool: testTool', + metadata: expect.objectContaining({ + toolUseId: 'tool-abc123', + toolName: 'testTool', + }), + }), + ]), + }), + ]) + ) + }) + }) + + describe('error handling', () => { + it('propagates maxTokens error', async () => { + const model = new MockMessageModel().addTurn( + { type: 'textBlock', text: 'Partial' }, + { stopReason: 'maxTokens' } + ) + const agent = new Agent({ model }) + + await expect(agent.invoke('Test')).rejects.toThrow(MaxTokensError) + }) + }) + + describe('metrics on errors', () => { + it('tracks cycle count when maxTokens error occurs', async () => { + const model = new MockMessageModel() + .addTurn( + { type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }, + { + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + } + ) + .addTurn( + { type: 'textBlock', text: 'Partial' }, + { + stopReason: 'maxTokens', + usage: { inputTokens: 80, outputTokens: 20, totalTokens: 100 }, + } + ) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('Done')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + + const meter = (agent as any)._meter + await expect(agent.invoke('Test')).rejects.toThrow(MaxTokensError) + + expect(meter.metrics.cycleCount).toBe(2) + // Only the first turn's usage is accumulated; the second turn throws + // MaxTokensError inside streamAggregated before metadata reaches updateCycle + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 100, + outputTokens: 50, + totalTokens: 150, + }) + expect(meter.metrics.accumulatedMetrics).toStrictEqual({ + latencyMs: expect.any(Number), + }) + expect(meter.metrics.toolMetrics).toStrictEqual({ + testTool: { + callCount: 1, + successCount: 1, + errorCount: 0, + totalTime: expect.any(Number), + }, + }) + }) + + it('collects local traces for completed cycles when error occurs mid-run', async () => { + const model = new MockMessageModel() + .addTurn( + { type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }, + { + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + } + ) + .addTurn( + { type: 'textBlock', text: 'Partial' }, + { + stopReason: 'maxTokens', + usage: { inputTokens: 80, outputTokens: 20, totalTokens: 100 }, + } + ) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('Done')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + + const tracer = (agent as any)._tracer + await expect(agent.invoke('Test')).rejects.toThrow(MaxTokensError) + + // Cycle 1 completed (tool use), cycle 2 errored (maxTokens) + expect(tracer.localTraces).toEqual([ + expect.objectContaining({ + name: 'Cycle 1', + children: [ + expect.objectContaining({ name: 'stream_messages' }), + expect.objectContaining({ name: 'Tool: testTool' }), + ], + }), + expect.objectContaining({ + name: 'Cycle 2', + children: [expect.objectContaining({ name: 'stream_messages' })], + }), + ]) + }) + + it('tracks metrics when a hook throws an error', async () => { + const model = new MockMessageModel() + .addTurn( + { type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }, + { + usage: { inputTokens: 60, outputTokens: 25, totalTokens: 85 }, + } + ) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('Result')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + + agent.addHook(BeforeToolsEvent, () => { + throw new Error('Hook failure') + }) + + const meter = (agent as any)._meter + await expect(agent.invoke('Test')).rejects.toThrow('Hook failure') + + // The hook throws after the model returns but before tools execute, + // so the first cycle's model usage is recorded but no tool metrics exist + expect(meter.metrics.cycleCount).toBe(1) + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 60, + outputTokens: 25, + totalTokens: 85, + }) + expect(meter.metrics.accumulatedMetrics).toStrictEqual({ + latencyMs: expect.any(Number), + }) + expect(meter.metrics.toolMetrics).toStrictEqual({}) + }) + }) + + describe('hook error cleanup', () => { + it('fires AfterInvocationEvent when a mid-stream hook throws', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('ok')], + }) + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(AfterToolCallEvent, () => { + throw new Error('hook error') + }) + + const afterInvocationCallback = vi.fn() + agent.addHook(AfterInvocationEvent, afterInvocationCallback) + + await expect(agent.invoke('Test')).rejects.toThrow('hook error') + expect(afterInvocationCallback).toHaveBeenCalledOnce() + }) + + it('fires AfterToolsEvent when a mid-stream hook throws', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('ok')], + }) + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(AfterToolCallEvent, () => { + throw new Error('hook error') + }) + + const afterToolsCallback = vi.fn() + agent.addHook(AfterToolsEvent, afterToolsCallback) + + await expect(agent.invoke('Test')).rejects.toThrow('hook error') + expect(afterToolsCallback).toHaveBeenCalledOnce() + }) + + it('does not fire AfterInvocationEvent when BeforeInvocationEvent hook throws', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + agent.addHook(BeforeInvocationEvent, () => { + throw new Error('before hook error') + }) + + const afterInvocationCallback = vi.fn() + agent.addHook(AfterInvocationEvent, afterInvocationCallback) + + await expect(agent.invoke('Test')).rejects.toThrow('before hook error') + expect(afterInvocationCallback).not.toHaveBeenCalled() + }) + + it('does not fire AfterToolsEvent when BeforeToolsEvent hook throws', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('ok')], + }) + ) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + agent.addHook(BeforeToolsEvent, () => { + throw new Error('before tools hook error') + }) + + const afterToolsCallback = vi.fn() + agent.addHook(AfterToolsEvent, afterToolsCallback) + + await expect(agent.invoke('Test')).rejects.toThrow('before tools hook error') + expect(afterToolsCallback).not.toHaveBeenCalled() + }) + }) + }) + + describe('API consistency', () => { + it('invoke() and stream() produce same final result', async () => { + const model1 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Consistent response' }) + const model2 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Consistent response' }) + + const agent1 = new Agent({ model: model1 }) + const agent2 = new Agent({ model: model2 }) + + const invokeResult = await agent1.invoke('Test') + const { result: streamResult } = await collectGenerator(agent2.stream('Test')) + + expect(invokeResult.stopReason).toBe(streamResult.stopReason) + expect(invokeResult.lastMessage.content).toEqual(streamResult.lastMessage.content) + }) + + it('both methods produce same result with tool use', async () => { + const createToolAndModels = () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'id', input: {} }) + .addTurn({ type: 'textBlock', text: 'Final' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'id', + status: 'success' as const, + content: [new TextBlock('Tool ran')], + }) + ) + + return { model, tool } + } + + const { model: model1, tool: tool1 } = createToolAndModels() + const { model: model2, tool: tool2 } = createToolAndModels() + + const agent1 = new Agent({ model: model1, tools: [tool1] }) + const agent2 = new Agent({ model: model2, tools: [tool2] }) + + const invokeResult = await agent1.invoke('Use tool') + const { result: streamResult } = await collectGenerator(agent2.stream('Use tool')) + + expect(invokeResult).toEqual( + expect.objectContaining({ + stopReason: streamResult.stopReason, + lastMessage: streamResult.lastMessage, + traces: streamResult.traces?.map((t) => + expect.objectContaining({ + name: t.name, + children: expect.arrayContaining( + Array(t.children.length).fill(expect.objectContaining({ name: expect.any(String) })) + ), + }) + ), + }) + ) + }) + }) + + describe('messages', () => { + it('returns array of messages', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + expect(agent.messages).toEqual([]) + }) + + it('reflects conversation history after invoke', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model }) + + await agent.invoke('Hello') + + expect(agent.messages).toEqual([ + expect.objectContaining({ + role: 'user', + content: [{ type: 'textBlock', text: 'Hello' }], + }), + expect.objectContaining({ + role: 'assistant', + content: [{ type: 'textBlock', text: 'Response' }], + }), + ]) + }) + }) + + describe('printer configuration', () => { + it('validates output when printer is enabled', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello world' }) + + // Capture output + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + // Create agent with custom printer for testing + const agent = new Agent({ model, printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + await collectGenerator(agent.stream('Test')) + + // Validate that text was output + const allOutput = outputs.join('') + expect(allOutput).toContain('Hello world') + }) + + it('does not create printer when printer is false', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + expect(agent).toBeDefined() + expect((agent as any)._printer).toBeUndefined() + }) + + it('defaults to printer=true when not specified', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + expect(agent).toBeDefined() + expect((agent as any)._printer).toBeDefined() + }) + + it('agent works correctly with printer disabled', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false }) + + const { result } = await collectGenerator(agent.stream('Test')) + + expect(result).toBeDefined() + expect(result.lastMessage.content).toEqual([{ type: 'textBlock', text: 'Hello' }]) + }) + }) + + describe('concurrency guards', () => { + it('prevents parallel invocations', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model }) + + // Test parallel invoke() calls + const invokePromise1 = agent.invoke('First') + const invokePromise2 = agent.invoke('Second') + + await expect(invokePromise2).rejects.toThrow(ConcurrentInvocationError) + await expect(invokePromise1).resolves.toBeDefined() + }) + + it('allows sequential invocations after lock is released', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'First response' }) + .addTurn({ type: 'textBlock', text: 'Second response' }) + const agent = new Agent({ model }) + + const result1 = await agent.invoke('First') + expect(result1.lastMessage.content).toEqual([{ type: 'textBlock', text: 'First response' }]) + + const result2 = await agent.invoke('Second') + expect(result2.lastMessage.content).toEqual([{ type: 'textBlock', text: 'Second response' }]) + }) + + it('releases lock after errors and abandoned streams', async () => { + // Test error case + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'Partial' }, { stopReason: 'maxTokens' }) + .addTurn({ type: 'textBlock', text: 'Success' }) + const agent = new Agent({ model }) + + await expect(agent.invoke('First')).rejects.toThrow(MaxTokensError) + + const result = await agent.invoke('Second') + expect(result.lastMessage.content).toEqual([{ type: 'textBlock', text: 'Success' }]) + }) + }) + + describe('nested tool arrays', () => { + describe('flattens nested arrays at any depth', () => { + const tool1 = createRandomTool() + const tool2 = createRandomTool() + const tool3 = createRandomTool() + + it.for([ + ['flat array', [tool1, tool2, tool3], [tool1, tool2, tool3]], + ['single tool', [tool1], [tool1]], + ['empty array', [], []], + ['single level nesting', [[tool1, tool2], tool3], [tool1, tool2, tool3]], + ['empty nested arrays', [[], tool1, []], [tool1]], + ['deeply nested', [[[tool1]], [tool2], tool3], [tool1, tool2, tool3]], + ['mixed nesting', [[tool1, [tool2]], tool3], [tool1, tool2, tool3]], + ['very deep nesting', [[[[tool1]]]], [tool1]], + ])('%i', ([, input, expected]) => { + const agent = new Agent({ tools: input as ToolList }) + expect(agent.tools).toEqual(expected) + }) + }) + + it('accepts undefined tools', () => { + const agent = new Agent({}) + + expect(agent.tools).toEqual([]) + }) + + it('catches duplicate tool names across nested arrays', () => { + const tool1 = createRandomTool('duplicate') + const tool2 = createRandomTool('duplicate') + + expect(() => new Agent({ tools: [[tool1], [tool2]] })).toThrow("Tool with name 'duplicate' already registered") + }) + }) + + describe('systemPrompt configuration', () => { + describe('when provided as string SystemPromptData', () => { + it('accepts and stores string system prompt', () => { + const agent = new Agent({ systemPrompt: 'You are a helpful assistant' }) + expect(agent).toBeDefined() + }) + }) + + describe('when provided as array SystemPromptData', () => { + it('converts TextBlockData to TextBlock', () => { + const agent = new Agent({ systemPrompt: [{ text: 'System prompt text' }] }) + expect(agent).toBeDefined() + }) + + it('converts mixed block data types', () => { + const agent = new Agent({ + systemPrompt: [{ text: 'First block' }, { cachePoint: { cacheType: 'default' } }, { text: 'Second block' }], + }) + expect(agent).toBeDefined() + }) + }) + + describe('when provided as SystemPrompt (class instances)', () => { + it('accepts array of class instances', () => { + const systemPrompt = [new TextBlock('System prompt'), new CachePointBlock({ cacheType: 'default' })] + const agent = new Agent({ systemPrompt }) + expect(agent).toBeDefined() + }) + }) + + describe('when modifying systemPrompt', () => { + it('allows systemPrompt to be set after initialization', () => { + const agent = new Agent({ systemPrompt: 'Initial prompt' }) + + agent.systemPrompt = 'Updated prompt' + + expect(agent.systemPrompt).toEqual('Updated prompt') + }) + + it('allows systemPrompt to be changed between turns', async () => { + const firstModel = new MockMessageModel().addTurn({ type: 'textBlock', text: 'First response' }) + + const streamSpy = vi.spyOn(firstModel, 'stream') + + const agent = new Agent({ model: firstModel, systemPrompt: [new TextBlock('You are a helpful assistant')] }) + + // First invocation with initial system prompt + await agent.invoke('First prompt') + expect(agent.systemPrompt).toEqual([new TextBlock('You are a helpful assistant')]) + + // Should have been called with the given promp + expect(streamSpy).toHaveBeenCalledWith( + expect.any(Array), + expect.objectContaining({ + systemPrompt: [new TextBlock('You are a helpful assistant')], + toolSpecs: [], + }) + ) + + // Change system prompt and model + agent.systemPrompt = 'You are a coding expert' + + // Second invocation should use new system prompt + streamSpy.mockReset() + await agent.invoke('Second prompt') + expect(agent.systemPrompt).toEqual('You are a coding expert') + expect(streamSpy).toHaveBeenCalledWith( + expect.any(Array), + expect.objectContaining({ + systemPrompt: 'You are a coding expert', + toolSpecs: [], + }) + ) + }) + }) + }) + + describe('model property', () => { + describe('when accessing the model field', () => { + it('returns the configured model instance', () => { + const model = new MockMessageModel() + const agent = new Agent({ model }) + + expect(agent.model).toBe(model) + }) + + it('returns default BedrockModel when no model provided', () => { + const agent = new Agent() + + expect(agent.model).toBeDefined() + expect(agent.model.constructor.name).toBe('BedrockModel') + }) + }) + + describe('when modifying the model field', () => { + it('updates the model instance', () => { + const initialModel = new MockMessageModel() + const newModel = new MockMessageModel() + const agent = new Agent({ model: initialModel }) + + agent.model = newModel + + expect(agent.model).toBe(newModel) + expect(agent.model).not.toBe(initialModel) + }) + + it('allows model change to persist across invocations', async () => { + const firstModel = new MockMessageModel().addTurn({ type: 'textBlock', text: 'First response' }) + const secondModel = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Second response' }) + const agent = new Agent({ model: firstModel }) + + // First invocation with initial model + const firstResult = await agent.invoke('First prompt') + expect(firstResult.lastMessage?.content[0]).toEqual(new TextBlock('First response')) + + // Change model + agent.model = secondModel + + // Second invocation should use new model + const secondResult = await agent.invoke('Second prompt') + expect(secondResult.lastMessage?.content[0]).toEqual(new TextBlock('Second response')) + }) + + it('successfully switches between different model providers', async () => { + const bedrockModel = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Bedrock response' }) + const openaiModel = new MockMessageModel().addTurn({ type: 'textBlock', text: 'OpenAI response' }) + const agent = new Agent({ model: bedrockModel }) + + // First invocation + const firstResult = await agent.invoke('First prompt') + expect(firstResult.lastMessage?.content[0]).toEqual(new TextBlock('Bedrock response')) + + // Switch to different provider + agent.model = openaiModel + + // Second invocation with new provider + const secondResult = await agent.invoke('Second prompt') + expect(secondResult.lastMessage?.content[0]).toEqual(new TextBlock('OpenAI response')) + }) + }) + }) + + describe('multimodal input', () => { + describe('with string input', () => { + it('creates user message with single TextBlock', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + await agent.invoke('Hello') + + expect(agent.messages).toHaveLength(2) + expect(agent.messages[0]).toEqual( + new Message({ + role: 'user', + content: [new TextBlock('Hello')], + }) + ) + }) + }) + + describe('with ContentBlock[] input', () => { + it('creates single user message with single TextBlock', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + await agent.invoke([new TextBlock('Hello')]) + + expect(agent.messages).toHaveLength(2) + expect(agent.messages[0]).toEqual( + new Message({ + role: 'user', + content: [new TextBlock('Hello')], + }) + ) + }) + + it('creates single user message with multiple blocks', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + const contentBlocks = [new TextBlock('Analyze this'), new TextBlock('and this')] + + await agent.invoke(contentBlocks) + + expect(agent.messages).toHaveLength(2) + expect(agent.messages[0]).toEqual( + new Message({ + role: 'user', + content: contentBlocks, + }) + ) + }) + + it('supports all ContentBlock types', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + const contentBlocks = [ + new TextBlock('Text content'), + new ToolUseBlock({ name: 'tool1', toolUseId: 'id-1', input: { key: 'value' } }), + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Result')], + }), + new ReasoningBlock({ text: 'My reasoning' }), + new CachePointBlock({ cacheType: 'default' }), + new GuardContentBlock({ text: { text: 'Guard content', qualifiers: ['grounding_source'] } }), + new ImageBlock({ + format: 'png', + source: { url: 'https://example.com/image.png' }, + }), + new VideoBlock({ + format: 'mp4', + source: { location: { type: 's3', uri: 's3://bucket/video.mp4' } }, + }), + new DocumentBlock({ + format: 'pdf', + name: 'doc.pdf', + source: { bytes: new Uint8Array([1, 2, 3]) }, + }), + ] + + await agent.invoke(contentBlocks) + + expect(agent.messages).toHaveLength(2) + expect(agent.messages[0]).toEqual( + new Message({ + role: 'user', + content: contentBlocks, + }) + ) + }) + + it('handles empty ContentBlock array', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + await agent.invoke([]) + + expect(agent.messages).toHaveLength(1) // Only response message added + }) + + it('accepts ContentBlockData[] and converts to ContentBlock[]', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + await agent.invoke([ + { text: 'Hello from data format' }, + { + toolUse: { + name: 'testTool', + toolUseId: 'id-1', + input: { key: 'value' }, + }, + }, + { + toolResult: { + toolUseId: 'id-1', + status: 'success' as const, + content: [{ text: 'Tool result' }, { json: { result: 42 } }], + }, + }, + { reasoning: { text: 'My reasoning' } }, + { cachePoint: { cacheType: 'default' as const } }, + { guardContent: { text: { text: 'Guard text', qualifiers: ['query' as const] } } }, + { + image: { + format: 'png' as const, + source: { url: 'https://example.com/image.png' }, + }, + }, + { + video: { + format: 'mp4' as const, + source: { location: { type: 's3' as const, uri: 's3://bucket/video.mp4' } }, + }, + }, + { + document: { + format: 'pdf' as const, + name: 'doc.pdf', + source: { bytes: new Uint8Array([1, 2, 3]) }, + }, + }, + ]) + + expect(agent.messages).toHaveLength(2) + const userMessage = agent.messages[0]! + expect(userMessage.role).toBe('user') + expect(userMessage.content).toHaveLength(9) + expect(userMessage.content[0]).toEqual(new TextBlock('Hello from data format')) + expect(userMessage.content[1]).toEqual( + new ToolUseBlock({ name: 'testTool', toolUseId: 'id-1', input: { key: 'value' } }) + ) + }) + }) + + describe('with Message[] input', () => { + it('appends single message to conversation', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + const userMessage = new Message({ + role: 'user', + content: [new TextBlock('Hello')], + }) + + await agent.invoke([userMessage]) + + expect(agent.messages).toHaveLength(2) + expect(agent.messages[0]).toEqual(userMessage) + }) + + it('appends multiple messages in order', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + const messages = [ + new Message({ + role: 'user', + content: [new TextBlock('First message')], + }), + new Message({ + role: 'assistant', + content: [new TextBlock('Second message')], + }), + new Message({ + role: 'user', + content: [new TextBlock('Third message')], + }), + ] + + await agent.invoke(messages) + + expect(agent.messages).toHaveLength(4) // 3 input + 1 response + expect(agent.messages[0]).toEqual(messages[0]) + expect(agent.messages[1]).toEqual(messages[1]) + expect(agent.messages[2]).toEqual(messages[2]) + }) + + it('handles empty Message array', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + await agent.invoke([]) + + expect(agent.messages).toHaveLength(1) // Only response message added + }) + + it('accepts MessageData[] and converts to Message[]', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('Response')) + const agent = new Agent({ model }) + + const messageDataArray = [ + { + role: 'user' as const, + content: [{ text: 'First message' }], + }, + { + role: 'assistant' as const, + content: [{ text: 'Second message' }], + }, + ] + + await agent.invoke(messageDataArray) + + expect(agent.messages).toHaveLength(3) // 2 input + 1 response + expect(agent.messages[0]).toEqual( + new Message({ + role: 'user', + content: [new TextBlock('First message')], + }) + ) + expect(agent.messages[1]).toEqual( + new Message({ + role: 'assistant', + content: [new TextBlock('Second message')], + }) + ) + }) + }) + }) + + describe('model initialization', () => { + describe('when model is a string', () => { + it('creates BedrockModel with specified modelId', () => { + const agent = new Agent({ model: 'anthropic.claude-3-5-sonnet-20240620-v1:0' }) + + expect(agent.model).toBeDefined() + expect(agent.model.constructor.name).toBe('BedrockModel') + expect(agent.model.getConfig().modelId).toBe('anthropic.claude-3-5-sonnet-20240620-v1:0') + }) + + it('creates BedrockModel with custom model ID', () => { + const customModelId = 'custom.model.id' + const agent = new Agent({ model: customModelId }) + + expect(agent.model.getConfig().modelId).toBe(customModelId) + }) + }) + + describe('when model is explicit BedrockModel', () => { + it('uses provided BedrockModel instance', () => { + const explicitModel = new BedrockModel({ modelId: 'explicit-model-id' }) + const agent = new Agent({ model: explicitModel }) + + expect(agent.model).toBe(explicitModel) + expect(agent.model.getConfig().modelId).toBe('explicit-model-id') + }) + }) + + describe('when no model is provided', () => { + it('creates default BedrockModel', () => { + const agent = new Agent() + + expect(agent.model).toBeDefined() + expect(agent.model.constructor.name).toBe('BedrockModel') + }) + }) + + describe('behavior parity', () => { + it('string model behaves identically to explicit BedrockModel with same modelId', () => { + const modelId = 'anthropic.claude-3-5-sonnet-20240620-v1:0' + + // Create agent with string model ID + const agentWithString = new Agent({ model: modelId }) + + // Create agent with explicit BedrockModel + const explicitModel = new BedrockModel({ modelId }) + const agentWithExplicit = new Agent({ model: explicitModel }) + + // Both should have same modelId + expect(agentWithString.model.getConfig().modelId).toBe(agentWithExplicit.model.getConfig().modelId) + expect(agentWithString.model.getConfig().modelId).toBe(modelId) + }) + }) + }) + + describe('structured output', () => { + it('returns structured output when schema provided and tool used', async () => { + const schema = z.object({ name: z.string(), age: z.number() }) + + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: { name: 'John', age: 30 }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + + const result = await agent.invoke('Test') + + expect(result.structuredOutput).toEqual({ name: 'John', age: 30 }) + expect(model.callCount).toBe(1) + }) + + it('forces structured output tool when model does not use it', async () => { + const schema = z.object({ value: z.number() }) + + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'First response' }) + .addTurn({ type: 'toolUseBlock', name: 'strands_structured_output', toolUseId: 'tool-1', input: { value: 42 } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + + const result = await agent.invoke('Test') + + expect(result.structuredOutput).toEqual({ value: 42 }) + }) + + it('does not send assistant-ended conversation when forcing structured output retry', async () => { + // Regression for https://github.com/strands-agents/sdk-typescript/issues/1039 + // When the model responds with plain text instead of calling the structured output tool, + // the forced-retry model call must not see a conversation ending with an assistant message. + // Bedrock/Anthropic-family models reject assistant message prefill. + const schema = z.object({ value: z.number() }) + + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'Plain text, no tool call' }) + .addTurn({ type: 'toolUseBlock', name: 'strands_structured_output', toolUseId: 'tool-1', input: { value: 42 } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + // Snapshot the role sequence at each model call, since `messages` is passed by reference + // and mutates during the agent loop. + const roleSnapshots: string[][] = [] + const originalStream = model.stream.bind(model) + vi.spyOn(model, 'stream').mockImplementation((messages, options) => { + roleSnapshots.push(messages.map((m) => m.role)) + return originalStream(messages, options) + }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + await agent.invoke('Test') + + expect(roleSnapshots.length).toBeGreaterThanOrEqual(2) + + // The forced-retry (second) call must not see a conversation ending with an assistant turn. + const secondCallRoles = roleSnapshots[1]! + expect(secondCallRoles[secondCallRoles.length - 1]).toBe('user') + }) + + it('throws StructuredOutputError when model refuses to use tool after forcing', async () => { + const schema = z.object({ value: z.number() }) + + // Model returns text twice - once normally, once when forced + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + + await expect(agent.invoke('Test')).rejects.toThrow(StructuredOutputError) + }) + + it('throws MaxTokensError when maxTokens reached before structured output', async () => { + const schema = z.object({ value: z.number() }) + + const model = new MockMessageModel().addTurn( + { type: 'textBlock', text: 'Partial...' }, + { stopReason: 'maxTokens' } + ) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + + await expect(agent.invoke('Test')).rejects.toThrow(MaxTokensError) + }) + + it('retries with validation feedback when structured output tool returns error', async () => { + const schema = z.object({ name: z.string(), age: z.number() }) + + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: { name: 'John', age: 'invalid' }, + }) + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-2', + input: { name: 'John', age: 30 }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + + const result = await agent.invoke('Test') + + expect(result.structuredOutput).toEqual({ name: 'John', age: 30 }) + }) + + it('works without structured output schema', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + + const agent = new Agent({ model }) + + const result = await agent.invoke('Test') + + expect(result.structuredOutput).toBeUndefined() + }) + + it('cleans up structured output tool after invocation', async () => { + const schema = z.object({ value: z.number() }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'strands_structured_output', toolUseId: 'tool-1', input: { value: 42 } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + + await agent.invoke('Test') + + const toolNames = agent.tools.map((t) => t.name) + expect(toolNames).not.toContain('strands_structured_output') + }) + + it('cleans up structured output tool even when error occurs', async () => { + const schema = z.object({ value: z.number() }) + + const model = new MockMessageModel().addTurn( + { type: 'textBlock', text: 'Partial...' }, + { stopReason: 'maxTokens' } + ) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + + await expect(agent.invoke('Test')).rejects.toThrow() + + const toolNames = agent.tools.map((t) => t.name) + expect(toolNames).not.toContain('strands_structured_output') + }) + + it('validates nested objects in structured output', async () => { + const schema = z.object({ + user: z.object({ + name: z.string(), + age: z.number(), + }), + }) + + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: { user: { name: 'Alice', age: 25 } }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + + const result = await agent.invoke('Test') + + expect(result.structuredOutput).toEqual({ user: { name: 'Alice', age: 25 } }) + }) + + it('validates arrays in structured output', async () => { + const schema = z.object({ + items: z.array(z.string()), + }) + + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: { items: ['a', 'b', 'c'] }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + + const result = await agent.invoke('Test') + + expect(result.structuredOutput).toEqual({ items: ['a', 'b', 'c'] }) + }) + + it('uses per-invocation override schema and restores constructor schema on next call', async () => { + const constructorSchema = z.object({ name: z.string() }) + const overrideSchema = z.object({ value: z.number() }) + + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: { value: 99 }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-2', + input: { name: 'Bob' }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, structuredOutputSchema: constructorSchema }) + + const first = await agent.invoke('First', { structuredOutputSchema: overrideSchema }) + expect(first.structuredOutput).toEqual({ value: 99 }) + + const second = await agent.invoke('Second') + expect(second.structuredOutput).toEqual({ name: 'Bob' }) + }) + + it('skips structured output extraction when AfterToolsEvent.endTurn halts the loop', async () => { + const schema = z.object({ name: z.string() }) + + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: { name: 'John' }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + agent.addHook(AfterToolsEvent, (event: AfterToolsEvent) => { + event.endTurn = true + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(result.structuredOutput).toBeUndefined() + expect(model.callCount).toBe(1) + }) + }) +}) + +describe('Agent._redactLastMessage', () => { + const redactMessage = '[REDACTED]' + + it('redacts last user message with only text blocks', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model }) + + // Add a user message + agent['messages'].push( + new Message({ + role: 'user', + content: [new TextBlock('sensitive content')], + }) + ) + + agent['_redactLastMessage'](redactMessage) + + const lastMessage = agent['messages'][agent['messages'].length - 1]! + expect(lastMessage.role).toBe('user') + expect(lastMessage.content).toHaveLength(1) + expect(lastMessage.content[0]!.type).toBe('textBlock') + expect((lastMessage.content[0] as TextBlock).text).toBe(redactMessage) + }) + + it('preserves tool result blocks with redacted content', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model }) + + // Add a user message with tool result and text blocks + agent['messages'].push( + new Message({ + role: 'user', + content: [ + new TextBlock('some text'), + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('tool result content')], + }), + new TextBlock('more text'), + new ToolResultBlock({ + toolUseId: 'tool-2', + status: 'error', + content: [new TextBlock('error content')], + }), + ], + }) + ) + + agent['_redactLastMessage'](redactMessage) + + const lastMessage = agent['messages'][agent['messages'].length - 1]! + expect(lastMessage.role).toBe('user') + expect(lastMessage.content).toHaveLength(2) + + // Only tool result blocks should remain + expect(lastMessage.content[0]!.type).toBe('toolResultBlock') + expect(lastMessage.content[1]!.type).toBe('toolResultBlock') + + // Tool result blocks should have redacted content but preserve structure + const toolResult1 = lastMessage.content[0] as ToolResultBlock + expect(toolResult1.toolUseId).toBe('tool-1') + expect(toolResult1.status).toBe('success') + expect(toolResult1.content).toHaveLength(1) + expect((toolResult1.content[0] as TextBlock).text).toBe(redactMessage) + + const toolResult2 = lastMessage.content[1] as ToolResultBlock + expect(toolResult2.toolUseId).toBe('tool-2') + expect(toolResult2.status).toBe('error') + expect(toolResult2.content).toHaveLength(1) + expect((toolResult2.content[0] as TextBlock).text).toBe(redactMessage) + }) + + it('does not redact when last message is not from user', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model }) + + // Add an assistant message + const assistantMessage = new Message({ + role: 'assistant', + content: [new TextBlock('assistant response')], + }) + agent['messages'].push(assistantMessage) + + const originalContent = assistantMessage.content + agent['_redactLastMessage'](redactMessage) + + const lastMessage = agent['messages'][agent['messages'].length - 1]! + expect(lastMessage.role).toBe('assistant') + expect(lastMessage.content).toBe(originalContent) + }) + + it('handles empty messages array gracefully', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model }) + + expect(() => agent['_redactLastMessage'](redactMessage)).not.toThrow() + expect(agent['messages']).toHaveLength(0) + }) +}) + +describe('_estimateInputTokens', () => { + function captureProjectedTokens(agent: Agent): Promise { + return new Promise((resolve) => { + agent.addHook(BeforeModelCallEvent, (event) => { + resolve(event.projectedInputTokens) + }) + }) + } + + it('uses full estimation on cold start (no prior usage metadata)', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Hello' }) + const countTokensSpy = vi.spyOn(model, 'countTokens') + countTokensSpy.mockResolvedValue(42) + + const agent = new Agent({ model, printer: false }) + const tokenPromise = captureProjectedTokens(agent) + await agent.invoke('Hi') + + expect(await tokenPromise).toBe(42) + expect(countTokensSpy).toHaveBeenCalledWith(expect.any(Array), expect.any(Object)) + }) + + it('uses known baseline when no new messages after last assistant', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Hello' }) + + const agent = new Agent({ + model, + printer: false, + messages: [ + new Message({ role: 'user', content: [new TextBlock('Hi')] }), + new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + metadata: { usage: { inputTokens: 100, outputTokens: 20, totalTokens: 120 } }, + }), + ], + }) + + // Invoke with no args — no new user message appended, so the last assistant + // message is still the final message and newMessages.length === 0 + const tokenPromise = captureProjectedTokens(agent) + await agent.invoke([]) + + // baseline = inputTokens(100) + outputTokens(20) = 120 + expect(await tokenPromise).toBe(120) + }) + + it('returns undefined projectedInputTokens when estimation fails', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Hello' }) + vi.spyOn(model, 'countTokens').mockRejectedValue(new Error('API unavailable')) + + const agent = new Agent({ model, printer: false }) + const tokenPromise = captureProjectedTokens(agent) + await agent.invoke('Hi') + + expect(await tokenPromise).toBeUndefined() + }) + + it('estimates delta for new messages after last assistant', async () => { + const model = new MockMessageModel() + model + .addTurn([{ type: 'toolUseBlock', name: 'test', toolUseId: 'id-1', input: {} }], { + usage: { inputTokens: 100, outputTokens: 30, totalTokens: 130 }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + const countTokensSpy = vi.spyOn(model, 'countTokens') + countTokensSpy.mockResolvedValue(50) + + const tool = createMockTool( + 'test', + () => + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success' as const, + content: [new TextBlock('result')], + }) + ) + const agent = new Agent({ model, tools: [tool], printer: false }) + + // Capture the second BeforeModelCallEvent (after tool execution) + let callCount = 0 + const tokenPromise = new Promise((resolve) => { + agent.addHook(BeforeModelCallEvent, (event) => { + callCount++ + if (callCount === 2) resolve(event.projectedInputTokens) + }) + }) + + await agent.invoke('Use the tool') + + // baseline (100+30) + estimated delta (50) = 180 + expect(await tokenPromise).toBe(180) + expect(countTokensSpy).toHaveBeenCalled() + }) + + it('uses baseline from prior invocation on second invoke', async () => { + const model = new MockMessageModel() + model + .addTurn( + { type: 'textBlock', text: 'First response' }, + { usage: { inputTokens: 200, outputTokens: 50, totalTokens: 250 } } + ) + .addTurn({ type: 'textBlock', text: 'Second response' }) + const countTokensSpy = vi.spyOn(model, 'countTokens') + countTokensSpy.mockResolvedValue(15) + + const agent = new Agent({ model, printer: false }) + await agent.invoke('First question') + + // Second invocation — the user message "Second question" is appended after + // the assistant message with usage metadata, so it hits the baseline + delta path + const tokenPromise = captureProjectedTokens(agent) + await agent.invoke('Second question') + + // baseline (200+50) + estimated delta for new user message (15) = 265 + expect(await tokenPromise).toBe(265) + }) +}) + +describe('normalizeToolUseNames', () => { + it('replaces invalid tool-use names with INVALID_TOOL_NAME before calling model', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'ok' }) + const streamSpy = vi.spyOn(model, 'stream') + + const agent = new Agent({ + model, + printer: false, + messages: [ + new Message({ role: 'user', content: [new TextBlock('do thing')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'bad name!', toolUseId: 'tu-1', input: {} })], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tu-1', + status: 'success', + content: [new TextBlock('result')], + }), + ], + }), + ], + }) + + await agent.invoke('continue') + + const sentMessages = streamSpy.mock.calls[0]?.[0] as Message[] + const sentToolUse = sentMessages + .find((m) => m.role === 'assistant')! + .content.find((b) => b.type === 'toolUseBlock') as ToolUseBlock + expect(sentToolUse).toStrictEqual(new ToolUseBlock({ name: 'INVALID_TOOL_NAME', toolUseId: 'tu-1', input: {} })) + + // Agent's stored history is not mutated. + const storedToolUse = agent.messages + .find((m) => m.role === 'assistant')! + .content.find((b) => b.type === 'toolUseBlock') as ToolUseBlock + expect(storedToolUse).toStrictEqual(new ToolUseBlock({ name: 'bad name!', toolUseId: 'tu-1', input: {} })) + }) + + it('preserves reasoningSignature on replaced tool-use blocks', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'ok' }) + const streamSpy = vi.spyOn(model, 'stream') + + const agent = new Agent({ + model, + printer: false, + messages: [ + new Message({ role: 'user', content: [new TextBlock('do thing')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'bad!', toolUseId: 'tu-1', input: {}, reasoningSignature: 'sig-abc' })], + }), + new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId: 'tu-1', status: 'success', content: [new TextBlock('ok')] })], + }), + ], + }) + + await agent.invoke('continue') + + const sentMessages = streamSpy.mock.calls[0]?.[0] as Message[] + const sentToolUse = sentMessages + .find((m) => m.role === 'assistant')! + .content.find((b) => b.type === 'toolUseBlock') as ToolUseBlock + expect(sentToolUse).toStrictEqual( + new ToolUseBlock({ + name: 'INVALID_TOOL_NAME', + toolUseId: 'tu-1', + input: {}, + reasoningSignature: 'sig-abc', + }) + ) + }) + + it('leaves valid names untouched', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'ok' }) + const streamSpy = vi.spyOn(model, 'stream') + + const agent = new Agent({ + model, + printer: false, + messages: [ + new Message({ role: 'user', content: [new TextBlock('do thing')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'good_tool-1', toolUseId: 'tu-1', input: {} })], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tu-1', + status: 'success', + content: [new TextBlock('result')], + }), + ], + }), + ], + }) + + await agent.invoke('continue') + + const sentMessages = streamSpy.mock.calls[0]?.[0] as Message[] + const sentToolUse = sentMessages + .find((m) => m.role === 'assistant')! + .content.find((b) => b.type === 'toolUseBlock') as ToolUseBlock + expect(sentToolUse).toStrictEqual(new ToolUseBlock({ name: 'good_tool-1', toolUseId: 'tu-1', input: {} })) + }) + + describe('MCP toolsChanged integration', () => { + it('removes old tools and adds new tools when onToolsChanged fires', async () => { + const mcpClient = new McpClient({ + transport: { start: vi.fn(), send: vi.fn(), close: vi.fn() } as never, + }) + + const initialTools = [ + new McpTool({ name: 'tool_a', description: 'A', inputSchema: {}, client: mcpClient }), + new McpTool({ name: 'tool_b', description: 'B', inputSchema: {}, client: mcpClient }), + ] + vi.spyOn(mcpClient, 'listTools').mockResolvedValue(initialTools) + + let capturedCallback: ((oldTools: string[], newTools: McpTool[]) => void) | undefined + const setterSpy = vi.spyOn(McpClient.prototype, 'onToolsChanged', 'set').mockImplementation((cb) => { + capturedCallback = cb + }) + + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'done' }) + const agent = new Agent({ model, tools: [mcpClient] }) + await agent.initialize() + + expect(agent.tools.map((t) => t.name)).toEqual(['tool_a', 'tool_b']) + expect(capturedCallback).toBeDefined() + + const newTools = [ + new McpTool({ name: 'tool_b', description: 'B-updated', inputSchema: {}, client: mcpClient }), + new McpTool({ name: 'tool_c', description: 'C', inputSchema: {}, client: mcpClient }), + ] + + capturedCallback!(['tool_a', 'tool_b'], newTools) + + expect(agent.tools.map((t) => t.name)).toEqual(['tool_b', 'tool_c']) + expect(agent.tools.find((t) => t.name === 'tool_b')!.description).toBe('B-updated') + + setterSpy.mockRestore() + }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/agent.tracer.test.node.ts b/strands-ts/src/agent/__tests__/agent.tracer.test.node.ts new file mode 100644 index 0000000000..a5ebbaf222 --- /dev/null +++ b/strands-ts/src/agent/__tests__/agent.tracer.test.node.ts @@ -0,0 +1,764 @@ +import { describe, expect, it, vi, beforeEach, type MockInstance } from 'vitest' +import { Agent } from '../agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { TextBlock, ToolUseBlock, ToolResultBlock, MaxTokensError, StructuredOutputError } from '../../index.js' +import { Tracer } from '../../telemetry/tracer.js' +import { z } from 'zod' + +interface MockTracerInstance { + startAgentSpan: MockInstance + endAgentSpan: MockInstance + startAgentLoopSpan: MockInstance + endAgentLoopSpan: MockInstance + startModelInvokeSpan: MockInstance + endModelInvokeSpan: MockInstance + startToolCallSpan: MockInstance + endToolCallSpan: MockInstance + withSpanContext: MockInstance +} + +vi.mock('../../telemetry/tracer.js', () => ({ + Tracer: vi.fn(function () { + return { + startAgentSpan: vi.fn().mockReturnValue({ mock: 'agentSpan' }), + endAgentSpan: vi.fn(), + startAgentLoopSpan: vi.fn().mockReturnValue({ mock: 'loopSpan' }), + endAgentLoopSpan: vi.fn(), + startModelInvokeSpan: vi.fn().mockReturnValue({ mock: 'modelSpan' }), + endModelInvokeSpan: vi.fn(), + startToolCallSpan: vi.fn().mockReturnValue({ mock: 'toolSpan' }), + endToolCallSpan: vi.fn(), + withSpanContext: vi.fn((_span, fn) => fn()), + } + }), +})) + +function getLatestTracer(): MockTracerInstance { + return vi.mocked(Tracer).mock.results.at(-1)!.value +} + +describe('Agent tracer integration', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('constructor', () => { + it('initializes Tracer with traceAttributes from config', () => { + const traceAttributes = { 'custom.attr': 'value' } + new Agent({ traceAttributes }) + + expect(Tracer).toHaveBeenCalledWith(traceAttributes) + }) + + it('initializes Tracer without traceAttributes when not provided', () => { + new Agent() + + expect(Tracer).toHaveBeenCalledWith(undefined) + }) + }) + + describe('name and id', () => { + it('defaults name to "Strands Agent"', () => { + const agent = new Agent() + + expect(agent.name).toBe('Strands Agent') + }) + + it('uses provided name', () => { + const agent = new Agent({ name: 'My Agent' }) + + expect(agent.name).toBe('My Agent') + }) + + it('defaults id to "agent"', () => { + const agent = new Agent() + + expect(agent.id).toBe('agent') + }) + + it('uses provided id', () => { + const agent = new Agent({ id: 'custom-id-123' }) + + expect(agent.id).toBe('custom-id-123') + }) + }) + + describe('agent span lifecycle', () => { + it('starts and ends agent span on successful invocation', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, name: 'TestAgent', id: 'test-id' }) + const tracer = getLatestTracer() + + await agent.invoke('Hi') + + expect(tracer.startAgentSpan).toHaveBeenCalledTimes(1) + expect(tracer.startAgentSpan).toHaveBeenCalledWith( + expect.objectContaining({ + agentName: 'TestAgent', + agentId: 'test-id', + modelId: 'test-model', + }) + ) + expect(tracer.endAgentSpan).toHaveBeenCalledTimes(1) + expect(tracer.endAgentSpan).toHaveBeenCalledWith( + { mock: 'agentSpan' }, + expect.objectContaining({ + response: expect.objectContaining({ role: 'assistant' }), + stopReason: 'endTurn', + }) + ) + }) + + it('ends agent span with error when invocation fails', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Partial' }, { stopReason: 'maxTokens' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + + await expect(agent.invoke('Hi')).rejects.toThrow(MaxTokensError) + + expect(tracer.startAgentSpan).toHaveBeenCalledTimes(1) + expect(tracer.endAgentSpan).toHaveBeenCalledTimes(1) + expect(tracer.endAgentSpan).toHaveBeenCalledWith( + { mock: 'agentSpan' }, + expect.objectContaining({ + error: expect.any(MaxTokensError), + }) + ) + }) + + it('includes systemPrompt in agent span when configured', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, systemPrompt: 'Be helpful' }) + const tracer = getLatestTracer() + + await agent.invoke('Hi') + + expect(tracer.startAgentSpan).toHaveBeenCalledWith( + expect.objectContaining({ + systemPrompt: 'Be helpful', + }) + ) + }) + + it('includes empty string systemPrompt in agent span', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, systemPrompt: '' }) + const tracer = getLatestTracer() + + await agent.invoke('Hi') + + expect(tracer.startAgentSpan).toHaveBeenCalledWith( + expect.objectContaining({ + systemPrompt: '', + }) + ) + }) + + it('omits systemPrompt from agent span when not configured', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + + await agent.invoke('Hi') + + expect(tracer.startAgentSpan).toHaveBeenCalledWith( + expect.not.objectContaining({ + systemPrompt: expect.anything(), + }) + ) + }) + + it('includes tools in agent span', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const tool = createMockTool( + 'myTool', + () => + new ToolResultBlock({ + toolUseId: 'id', + status: 'success', + content: [], + }) + ) + const agent = new Agent({ model, tools: [tool] }) + const tracer = getLatestTracer() + + await agent.invoke('Hi') + + expect(tracer.startAgentSpan).toHaveBeenCalledWith( + expect.objectContaining({ + tools: expect.arrayContaining([expect.objectContaining({ name: 'myTool' })]), + }) + ) + }) + }) + + describe('agent loop span lifecycle', () => { + it('starts and ends loop span for each cycle', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Done' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + + await agent.invoke('Hi') + + expect(tracer.startAgentLoopSpan).toHaveBeenCalledTimes(1) + expect(tracer.startAgentLoopSpan).toHaveBeenCalledWith(expect.objectContaining({ cycleId: 'cycle-1' })) + expect(tracer.endAgentLoopSpan).toHaveBeenCalledTimes(1) + expect(tracer.endAgentLoopSpan).toHaveBeenCalledWith({ mock: 'loopSpan' }) + }) + + it('creates multiple loop spans for multi-cycle invocations', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('Result')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + const tracer = getLatestTracer() + + await agent.invoke('Use tool') + + expect(tracer.startAgentLoopSpan).toHaveBeenCalledTimes(2) + expect(tracer.startAgentLoopSpan).toHaveBeenNthCalledWith(1, expect.objectContaining({ cycleId: 'cycle-1' })) + expect(tracer.startAgentLoopSpan).toHaveBeenNthCalledWith(2, expect.objectContaining({ cycleId: 'cycle-2' })) + expect(tracer.endAgentLoopSpan).toHaveBeenCalledTimes(2) + }) + + it('ends loop span with error when cycle fails', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Partial' }, { stopReason: 'maxTokens' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + + await expect(agent.invoke('Hi')).rejects.toThrow(MaxTokensError) + + expect(tracer.endAgentLoopSpan).toHaveBeenCalledWith( + { mock: 'loopSpan' }, + expect.objectContaining({ error: expect.any(MaxTokensError) }) + ) + }) + + it('ends loop span for cycle where structured output forces tool choice', async () => { + const schema = z.object({ value: z.number() }) + + // Turn 1: model returns text (no tool use) → triggers forced tool choice on next cycle + // Turn 2: model uses the structured output tool → tool succeeds, early exit + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'First response' }) + .addTurn({ type: 'toolUseBlock', name: 'strands_structured_output', toolUseId: 'tool-1', input: { value: 42 } }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + const tracer = getLatestTracer() + + await agent.invoke('Test') + + // Forced call gets its own cycle for accurate metrics and tracing + expect(tracer.startAgentLoopSpan).toHaveBeenCalledTimes(2) + expect(tracer.startAgentLoopSpan).toHaveBeenNthCalledWith(1, expect.objectContaining({ cycleId: 'cycle-1' })) + expect(tracer.startAgentLoopSpan).toHaveBeenNthCalledWith(2, expect.objectContaining({ cycleId: 'cycle-2' })) + expect(tracer.endAgentLoopSpan).toHaveBeenCalledTimes(2) + }) + }) + + describe('model invoke span lifecycle', () => { + it('starts and ends model span on successful model call', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + + await agent.invoke('Hi') + + expect(tracer.startModelInvokeSpan).toHaveBeenCalledTimes(1) + expect(tracer.startModelInvokeSpan).toHaveBeenCalledWith(expect.objectContaining({ modelId: 'test-model' })) + expect(tracer.endModelInvokeSpan).toHaveBeenCalledTimes(1) + expect(tracer.endModelInvokeSpan).toHaveBeenCalledWith( + { mock: 'modelSpan' }, + expect.objectContaining({ + output: expect.objectContaining({ role: 'assistant' }), + stopReason: 'endTurn', + }) + ) + }) + + it('ends model span with error when model call fails', async () => { + const model = new MockMessageModel().addTurn(new Error('Model failed')) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + + await expect(agent.invoke('Hi')).rejects.toThrow() + + expect(tracer.endModelInvokeSpan).toHaveBeenCalledWith( + { mock: 'modelSpan' }, + expect.objectContaining({ error: expect.any(Error) }) + ) + }) + + it('creates model span for each model call in multi-cycle invocation', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('Result')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + const tracer = getLatestTracer() + + await agent.invoke('Use tool') + + expect(tracer.startModelInvokeSpan).toHaveBeenCalledTimes(2) + expect(tracer.endModelInvokeSpan).toHaveBeenCalledTimes(2) + }) + }) + + describe('tool call span lifecycle', () => { + it('starts and ends tool span for each tool execution', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: { key: 'val' } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('Result')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + const tracer = getLatestTracer() + + await agent.invoke('Use tool') + + expect(tracer.startToolCallSpan).toHaveBeenCalledTimes(1) + expect(tracer.startToolCallSpan).toHaveBeenCalledWith({ + tool: expect.objectContaining({ + name: 'testTool', + toolUseId: 'tool-1', + input: { key: 'val' }, + }), + }) + expect(tracer.endToolCallSpan).toHaveBeenCalledTimes(1) + expect(tracer.endToolCallSpan).toHaveBeenCalledWith( + { mock: 'toolSpan' }, + expect.objectContaining({ + toolResult: expect.objectContaining({ toolUseId: 'tool-1', status: 'success' }), + }) + ) + }) + + it('ends tool span with error when tool is not found', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'missingTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model }) + const tracer = getLatestTracer() + + await agent.invoke('Use tool') + + expect(tracer.endToolCallSpan).toHaveBeenCalledWith( + { mock: 'toolSpan' }, + expect.objectContaining({ + toolResult: expect.objectContaining({ status: 'error' }), + }) + ) + }) + + it('ends tool span with error when tool throws', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'failTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('failTool', () => { + throw new Error('Tool exploded') + }) + + const agent = new Agent({ model, tools: [tool] }) + const tracer = getLatestTracer() + + await agent.invoke('Use tool') + + expect(tracer.endToolCallSpan).toHaveBeenCalledWith( + { mock: 'toolSpan' }, + expect.objectContaining({ + error: expect.any(Error), + toolResult: expect.objectContaining({ status: 'error' }), + }) + ) + }) + + it('creates spans for multiple tool calls in a single turn', async () => { + const model = new MockMessageModel() + .addTurn([ + new ToolUseBlock({ name: 'tool1', toolUseId: 'id-1', input: {} }), + new ToolUseBlock({ name: 'tool2', toolUseId: 'id-2', input: {} }), + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool1 = createMockTool( + 'tool1', + () => + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('R1')], + }) + ) + const tool2 = createMockTool( + 'tool2', + () => + new ToolResultBlock({ + toolUseId: 'id-2', + status: 'success', + content: [new TextBlock('R2')], + }) + ) + + const agent = new Agent({ model, tools: [tool1, tool2] }) + const tracer = getLatestTracer() + + await agent.invoke('Use tools') + + expect(tracer.startToolCallSpan).toHaveBeenCalledTimes(2) + expect(tracer.endToolCallSpan).toHaveBeenCalledTimes(2) + }) + + it('creates overlapping tool spans when toolExecutor is concurrent', async () => { + const model = new MockMessageModel() + .addTurn([ + new ToolUseBlock({ name: 'tool1', toolUseId: 'id-1', input: {} }), + new ToolUseBlock({ name: 'tool2', toolUseId: 'id-2', input: {} }), + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + // Tools sleep briefly so the concurrent executor has time to launch both + // before either resolves. The assertions below check call order, not + // wall-clock timing. + const sleep = (ms: number) => new Promise((r) => globalThis.setTimeout(r, ms)) + // eslint-disable-next-line require-yield + async function* sleepThenReturn(toolUseId: string, text: string) { + await sleep(20) + return new ToolResultBlock({ toolUseId, status: 'success', content: [new TextBlock(text)] }) + } + const tool1 = createMockTool('tool1', () => sleepThenReturn('id-1', 'R1')) + const tool2 = createMockTool('tool2', () => sleepThenReturn('id-2', 'R2')) + + const agent = new Agent({ model, tools: [tool1, tool2], toolExecutor: 'concurrent' }) + const tracer = getLatestTracer() + + // Record span lifecycle events in order. Sequential execution would + // produce [start:A, end:A, start:B, end:B]; concurrent execution + // interleaves so both starts precede both ends. + const events: string[] = [] + tracer.startToolCallSpan.mockImplementation((args: { tool: { toolUseId: string } }) => { + events.push(`start:${args.tool.toolUseId}`) + return { mock: 'toolSpan', id: args.tool.toolUseId } + }) + tracer.endToolCallSpan.mockImplementation((span: { id: string } | null) => { + if (span && 'id' in span) events.push(`end:${span.id}`) + }) + + await agent.invoke('Use tools') + + expect(tracer.startToolCallSpan).toHaveBeenCalledTimes(2) + expect(tracer.endToolCallSpan).toHaveBeenCalledTimes(2) + // Both starts happened before either end — i.e. the spans overlap. + expect(events.slice(0, 2).sort()).toEqual(['start:id-1', 'start:id-2']) + expect(events.slice(2, 4).sort()).toEqual(['end:id-1', 'end:id-2']) + }) + }) + + describe('token usage accumulation', () => { + it('passes accumulated usage to endAgentSpan', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + + await agent.invoke('Hi') + + expect(tracer.endAgentSpan).toHaveBeenCalledWith( + { mock: 'agentSpan' }, + expect.objectContaining({ + accumulatedUsage: expect.objectContaining({ + inputTokens: expect.any(Number), + outputTokens: expect.any(Number), + totalTokens: expect.any(Number), + }), + }) + ) + }) + }) + + describe('null span handling', () => { + it('completes successfully when startAgentSpan returns null', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + tracer.startAgentSpan.mockReturnValue(null) + + const result = await agent.invoke('Hi') + + expect(result.stopReason).toBe('endTurn') + expect(tracer.endAgentSpan).toHaveBeenCalledWith(null, expect.any(Object)) + }) + + it('completes successfully when startAgentLoopSpan returns null', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + tracer.startAgentLoopSpan.mockReturnValue(null) + + const result = await agent.invoke('Hi') + + expect(result.stopReason).toBe('endTurn') + expect(tracer.endAgentLoopSpan).toHaveBeenCalledWith(null) + }) + + it('completes successfully when startModelInvokeSpan returns null', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + tracer.startModelInvokeSpan.mockReturnValue(null) + + const result = await agent.invoke('Hi') + + expect(result.stopReason).toBe('endTurn') + }) + + it('completes successfully when startToolCallSpan returns null', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'testTool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('Result')], + }) + ) + + const agent = new Agent({ model, tools: [tool] }) + const tracer = getLatestTracer() + tracer.startToolCallSpan.mockReturnValue(null) + + const result = await agent.invoke('Use tool') + + expect(result.stopReason).toBe('endTurn') + expect(tracer.endToolCallSpan).toHaveBeenCalledWith(null, expect.any(Object)) + }) + }) + + describe('span context hierarchy', () => { + it('resets accumulated usage on each invocation', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'First' }) + .addTurn({ type: 'textBlock', text: 'Second' }) + const agent = new Agent({ model }) + const tracer = getLatestTracer() + + await agent.invoke('First') + await agent.invoke('Second') + + expect(tracer.startAgentSpan).toHaveBeenCalledTimes(2) + expect(tracer.endAgentSpan).toHaveBeenCalledTimes(2) + }) + }) + + describe('structured output and telemetry interaction', () => { + it('creates tool span for structured output tool execution', async () => { + const schema = z.object({ value: z.number() }) + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'strands_structured_output', toolUseId: 'tool-1', input: { value: 42 } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + const tracer = getLatestTracer() + + await agent.invoke('Test') + + expect(tracer.startToolCallSpan).toHaveBeenCalledWith({ + tool: expect.objectContaining({ name: 'strands_structured_output' }), + }) + expect(tracer.endToolCallSpan).toHaveBeenCalledWith( + { mock: 'toolSpan' }, + expect.objectContaining({ + toolResult: expect.objectContaining({ status: 'success' }), + }) + ) + }) + + it('ends agent span with error when model refuses structured output tool after forcing', async () => { + const schema = z.object({ value: z.number() }) + + // Single-turn model always returns text — first normally, then when forced + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'I refuse' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + const tracer = getLatestTracer() + + await expect(agent.invoke('Test')).rejects.toThrow(StructuredOutputError) + + expect(tracer.endAgentSpan).toHaveBeenCalledWith( + { mock: 'agentSpan' }, + expect.objectContaining({ error: expect.any(StructuredOutputError) }) + ) + }) + + it('ends cycle span with error on StructuredOutputError', async () => { + const schema = z.object({ value: z.number() }) + + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'I refuse' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + const tracer = getLatestTracer() + + await expect(agent.invoke('Test')).rejects.toThrow(StructuredOutputError) + + // Cycle 1: model returns text, triggers forced tool choice on next cycle + // Cycle 2: model still returns text, throws StructuredOutputError + expect(tracer.startAgentLoopSpan).toHaveBeenCalledTimes(2) + expect(tracer.endAgentLoopSpan).toHaveBeenCalledTimes(2) + expect(tracer.endAgentLoopSpan).toHaveBeenNthCalledWith(1, { mock: 'loopSpan' }) + expect(tracer.endAgentLoopSpan).toHaveBeenNthCalledWith( + 2, + { mock: 'loopSpan' }, + expect.objectContaining({ error: expect.any(StructuredOutputError) }) + ) + }) + + it('ends agent span with result on successful structured output', async () => { + const schema = z.object({ value: z.number() }) + + // Model calls structured output tool → early exit after successful validation + const model = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: { value: 42 }, + }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + const tracer = getLatestTracer() + + await agent.invoke('Test') + + expect(tracer.endAgentSpan).toHaveBeenCalledWith( + { mock: 'agentSpan' }, + expect.objectContaining({ + response: expect.objectContaining({ role: 'assistant' }), + stopReason: 'toolUse', + }) + ) + }) + + it('creates correct spans for validation retry cycle', async () => { + const schema = z.object({ name: z.string(), age: z.number() }) + + // Turn 1: invalid input → tool returns error, loop continues + // Turn 2: valid input → tool succeeds, early exit + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: { name: 'John', age: 'not-a-number' }, + }) + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-2', + input: { name: 'John', age: 30 }, + }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + const tracer = getLatestTracer() + + await agent.invoke('Test') + + // 2 cycles: invalid tool use, valid tool use with early exit + expect(tracer.startAgentLoopSpan).toHaveBeenCalledTimes(2) + expect(tracer.startAgentLoopSpan).toHaveBeenNthCalledWith(1, expect.objectContaining({ cycleId: 'cycle-1' })) + expect(tracer.startAgentLoopSpan).toHaveBeenNthCalledWith(2, expect.objectContaining({ cycleId: 'cycle-2' })) + expect(tracer.endAgentLoopSpan).toHaveBeenCalledTimes(2) + expect(tracer.endAgentLoopSpan).toHaveBeenNthCalledWith(1, { mock: 'loopSpan' }) + expect(tracer.endAgentLoopSpan).toHaveBeenNthCalledWith(2, { mock: 'loopSpan' }) + + // 2 model calls, one per cycle + expect(tracer.startModelInvokeSpan).toHaveBeenCalledTimes(2) + expect(tracer.endModelInvokeSpan).toHaveBeenCalledTimes(2) + for (let i = 1; i <= 2; i++) { + expect(tracer.endModelInvokeSpan).toHaveBeenNthCalledWith( + i, + { mock: 'modelSpan' }, + expect.objectContaining({ output: expect.objectContaining({ role: 'assistant' }) }) + ) + } + + // 2 tool calls: first with validation error, second succeeds + expect(tracer.startToolCallSpan).toHaveBeenCalledTimes(2) + expect(tracer.startToolCallSpan).toHaveBeenNthCalledWith(1, { + tool: expect.objectContaining({ name: 'strands_structured_output', toolUseId: 'tool-1' }), + }) + expect(tracer.startToolCallSpan).toHaveBeenNthCalledWith(2, { + tool: expect.objectContaining({ name: 'strands_structured_output', toolUseId: 'tool-2' }), + }) + expect(tracer.endToolCallSpan).toHaveBeenCalledTimes(2) + expect(tracer.endToolCallSpan).toHaveBeenNthCalledWith( + 1, + { mock: 'toolSpan' }, + expect.objectContaining({ + toolResult: expect.objectContaining({ toolUseId: 'tool-1', status: 'error' }), + }) + ) + expect(tracer.endToolCallSpan).toHaveBeenNthCalledWith( + 2, + { mock: 'toolSpan' }, + expect.objectContaining({ + toolResult: expect.objectContaining({ toolUseId: 'tool-2', status: 'success' }), + }) + ) + }) + + it('ends agent span with error on maxTokens with structured output schema', async () => { + const schema = z.object({ value: z.number() }) + + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Partial' }, { stopReason: 'maxTokens' }) + + const agent = new Agent({ model, structuredOutputSchema: schema }) + const tracer = getLatestTracer() + + await expect(agent.invoke('Test')).rejects.toThrow(MaxTokensError) + + expect(tracer.endAgentSpan).toHaveBeenCalledWith( + { mock: 'agentSpan' }, + expect.objectContaining({ error: expect.any(MaxTokensError) }) + ) + expect(tracer.endAgentLoopSpan).toHaveBeenCalledWith( + { mock: 'loopSpan' }, + expect.objectContaining({ error: expect.any(MaxTokensError) }) + ) + }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/printer.test.ts b/strands-ts/src/agent/__tests__/printer.test.ts new file mode 100644 index 0000000000..625dd26d10 --- /dev/null +++ b/strands-ts/src/agent/__tests__/printer.test.ts @@ -0,0 +1,267 @@ +import { describe, expect, it } from 'vitest' +import { AgentPrinter } from '../printer.js' +import { Agent } from '../agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { TextBlock, ToolResultBlock } from '../../types/messages.js' +import { BeforeToolCallEvent, BeforeToolsEvent } from '../../hooks/events.js' + +describe('AgentPrinter', () => { + describe('end-to-end scenarios', () => { + it('prints simple text output', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello world' }) + + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + const agent = new Agent({ model, printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + await collectGenerator(agent.stream('Test')) + + const allOutput = outputs.join('') + expect(allOutput).toBe('Hello world\n') + }) + + it('prints reasoning content wrapped in tags', async () => { + const model = new MockMessageModel().addTurn({ type: 'reasoningBlock', text: 'Let me think' }) + + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + const agent = new Agent({ model, printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + await collectGenerator(agent.stream('Test')) + + const allOutput = outputs.join('') + expect(allOutput).toBe('\n💭 Reasoning:\n Let me think\n\n') + }) + + it('prints text and reasoning together', async () => { + const model = new MockMessageModel().addTurn([ + { type: 'textBlock', text: 'Answer: ' }, + { type: 'reasoningBlock', text: 'thinking' }, + ]) + + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + const agent = new Agent({ model, printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + await collectGenerator(agent.stream('Test')) + + const allOutput = outputs.join('') + expect(allOutput).toBe('Answer: \n💭 Reasoning:\n thinking\n\n') + }) + + it('handles newlines in reasoning content', async () => { + const model = new MockMessageModel().addTurn({ + type: 'reasoningBlock', + text: 'First line\nSecond line\nThird line', + }) + + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + const agent = new Agent({ model, printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + await collectGenerator(agent.stream('Test')) + + const allOutput = outputs.join('') + const expected = ` +💭 Reasoning: + First line + Second line + Third line +\n` + expect(allOutput).toBe(expected) + }) + + it('prints tool execution', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calc', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Result: 4' }) + + const tool = createMockTool( + 'calc', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('4')], + }) + ) + + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + const agent = new Agent({ model, tools: [tool], printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + await collectGenerator(agent.stream('Test')) + + const allOutput = outputs.join('') + expect(allOutput).toBe('\n ⏳ calc\n\n🔧 Tool #1: calc\n✓ Tool completed\nResult: 4\n') + }) + + it('prints tool error', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'bad_tool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Error handled' }) + + const tool = createMockTool( + 'bad_tool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error' as const, + content: [new TextBlock('Failed')], + }) + ) + + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + const agent = new Agent({ model, tools: [tool], printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + await collectGenerator(agent.stream('Test')) + + const allOutput = outputs.join('') + expect(allOutput).toBe('\n ⏳ bad_tool\n\n🔧 Tool #1: bad_tool\n✗ Tool failed\nError handled\n') + }) + + it('prints denied tool with denied icon', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'dangerous_tool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Tool was denied' }) + + const tool = createMockTool( + 'dangerous_tool', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error' as const, + content: [new TextBlock('denied')], + }) + ) + + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + const agent = new Agent({ model, tools: [tool], printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + agent.addHook(BeforeToolCallEvent, (event: BeforeToolCallEvent) => { + event.cancel = 'Tool not allowed' + }) + + await collectGenerator(agent.stream('Test')) + + const allOutput = outputs.join('') + expect(allOutput).toBe( + '\n ⏳ dangerous_tool\n\n🚫 Tool #1: dangerous_tool (denied)\n✗ Tool failed\nTool was denied\n' + ) + }) + + it('prints batch cancel notice when BeforeToolsEvent cancels', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'tool_a', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool( + 'tool_a', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('ok')], + }) + ) + + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + const agent = new Agent({ model, tools: [tool], printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + agent.addHook(BeforeToolsEvent, (event: BeforeToolsEvent) => { + event.cancel = true + }) + + await collectGenerator(agent.stream('Test')) + + const allOutput = outputs.join('') + expect(allOutput).toBe('\n ⏳ tool_a\n\n🚫 All tools denied\n✗ Tool failed\nDone\n') + }) + + it('prints comprehensive scenario with all output types', async () => { + const model = new MockMessageModel() + .addTurn([ + { type: 'textBlock', text: 'Let me help you. ' }, + { type: 'reasoningBlock', text: 'I need to use the calculator' }, + { type: 'toolUseBlock', name: 'calculator', toolUseId: 'tool-1', input: { expr: '2+2' } }, + ]) + .addTurn([ + { type: 'textBlock', text: 'The calculation succeeded. ' }, + { type: 'reasoningBlock', text: 'Now trying validation' }, + { type: 'toolUseBlock', name: 'validator', toolUseId: 'tool-2', input: { value: 'test' } }, + ]) + .addTurn([ + { type: 'textBlock', text: 'All done. ' }, + { type: 'reasoningBlock', text: 'Task completed successfully' }, + ]) + + const calcTool = createMockTool( + 'calculator', + () => + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success' as const, + content: [new TextBlock('4')], + }) + ) + + const validatorTool = createMockTool( + 'validator', + () => + new ToolResultBlock({ + toolUseId: 'tool-2', + status: 'error' as const, + content: [new TextBlock('Validation failed')], + }) + ) + + const outputs: string[] = [] + const mockAppender = (text: string) => outputs.push(text) + + const agent = new Agent({ model, tools: [calcTool, validatorTool], printer: false }) + ;(agent as any)._printer = new AgentPrinter(mockAppender) + + await collectGenerator(agent.stream('Test')) + + const allOutput = outputs.join('') + const expected = [ + 'Let me help you. ', + '\n💭 Reasoning:\n I need to use the calculator\n', + '\n ⏳ calculator\n', + '\n🔧 Tool #1: calculator\n', + '✓ Tool completed\n', + 'The calculation succeeded. ', + '\n💭 Reasoning:\n Now trying validation\n', + '\n ⏳ validator\n', + '\n🔧 Tool #2: validator\n', + '✗ Tool failed\n', + 'All done. ', + '\n💭 Reasoning:\n Task completed successfully\n', + '\n', + ].join('') + + expect(allOutput).toBe(expected) + }) + }) +}) diff --git a/strands-ts/src/agent/__tests__/snapshot.test.ts b/strands-ts/src/agent/__tests__/snapshot.test.ts new file mode 100644 index 0000000000..68b61f7b3d --- /dev/null +++ b/strands-ts/src/agent/__tests__/snapshot.test.ts @@ -0,0 +1,568 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest' +import { Agent } from '../agent.js' +import type { Snapshot } from '../../types/snapshot.js' +import { SNAPSHOT_SCHEMA_VERSION } from '../../types/snapshot.js' +import { + ALL_SNAPSHOT_FIELDS, + SNAPSHOT_PRESETS, + createTimestamp, + resolveSnapshotFields, + takeSnapshot, + loadSnapshot, +} from '../snapshot.js' +import { Message, TextBlock, ToolUseBlock, ToolResultBlock } from '../../types/messages.js' +import { TestModelProvider } from '../../__fixtures__/model-test-helpers.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' + +// Fixed timestamp for testing +const MOCK_TIMESTAMP = '2026-01-15T12:00:00.000Z' + +/** + * Helper to create a test agent with a mock model + */ +function createTestAgent(): Agent { + return new Agent({ + model: new TestModelProvider(), + tools: [], + }) +} + +describe('Snapshot API', () => { + beforeEach(() => { + vi.useFakeTimers() + vi.setSystemTime(new Date(MOCK_TIMESTAMP)) + }) + + afterEach(() => { + vi.useRealTimers() + }) + + describe('constants', () => { + it('exports snapshot constants with correct values', () => { + expect(SNAPSHOT_SCHEMA_VERSION).toBe('1.0') + expect(ALL_SNAPSHOT_FIELDS).toEqual(['messages', 'state', 'systemPrompt', 'modelState', 'interrupts']) + expect(SNAPSHOT_PRESETS).toEqual({ + session: ['messages', 'state', 'systemPrompt', 'modelState', 'interrupts'], + }) + }) + }) + + describe('createTimestamp', () => { + it('returns ISO 8601 formatted timestamp', () => { + expect(createTimestamp()).toBe(MOCK_TIMESTAMP) + }) + }) + + describe('resolveSnapshotFields', () => { + it('throws error when no fields would be included', () => { + expect(() => resolveSnapshotFields({})).toThrow('No fields to include in snapshot') + }) + + it('returns session preset fields when preset is "session"', () => { + const fields = resolveSnapshotFields({ preset: 'session' }) + expect(fields).toEqual(new Set(['messages', 'state', 'systemPrompt', 'modelState', 'interrupts'])) + }) + + it('returns explicit fields when include is specified', () => { + const fields = resolveSnapshotFields({ include: ['messages', 'state'] }) + expect(fields).toEqual(new Set(['messages', 'state'])) + }) + + it('applies exclude after preset', () => { + const fields = resolveSnapshotFields({ preset: 'session', exclude: ['state'] }) + expect(fields).toEqual(new Set(['messages', 'systemPrompt', 'modelState', 'interrupts'])) + }) + + it('throws error for invalid preset', () => { + expect(() => resolveSnapshotFields({ preset: 'invalid' as any })).toThrow('Invalid preset: invalid') + }) + + it('throws error for invalid field names', () => { + expect(() => resolveSnapshotFields({ include: ['invalidField' as any] })).toThrow( + 'Invalid snapshot field: invalidField' + ) + }) + }) + + describe('takeSnapshot', () => { + let agent: Agent + + beforeEach(() => { + agent = createTestAgent() + }) + + it('creates snapshot with session preset', () => { + agent.messages.push(new Message({ role: 'user', content: [new TextBlock('Hello')] })) + agent.appState.set('key', 'value') + agent.systemPrompt = 'Test prompt' + + const snapshot = takeSnapshot(agent, { preset: 'session' }) + + expect(snapshot).toEqual({ + scope: 'agent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { + messages: [{ role: 'user', content: [{ text: 'Hello' }] }], + state: { key: 'value' }, + systemPrompt: 'Test prompt', + modelState: {}, + interrupts: { interrupts: {}, activated: false }, + }, + appData: {}, + }) + }) + + it('includes appData in snapshot', () => { + const snapshot = takeSnapshot(agent, { + preset: 'session', + appData: { customKey: 'customValue' }, + }) + expect(snapshot.appData).toEqual({ customKey: 'customValue' }) + }) + + it('excludes specified fields', () => { + agent.messages.push(new Message({ role: 'user', content: [new TextBlock('Hello')] })) + agent.appState.set('key', 'value') + + const snapshot = takeSnapshot(agent, { preset: 'session', exclude: ['messages'] }) + + expect(snapshot.data.messages).toBeUndefined() + expect(snapshot.data.state).toBeDefined() + }) + }) + + describe('loadSnapshot', () => { + let agent: Agent + + beforeEach(() => { + agent = createTestAgent() + }) + + it('throws error for incompatible schema version', () => { + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: '2.0', + createdAt: createTimestamp(), + data: {}, + appData: {}, + } + + expect(() => loadSnapshot(agent, snapshot)).toThrow( + 'Unsupported snapshot schema version: 2.0. Current version: 1.0' + ) + }) + + it('throws error for wrong scope', () => { + const snapshot: Snapshot = { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: createTimestamp(), + data: {}, + appData: {}, + } + + expect(() => loadSnapshot(agent, snapshot)).toThrow("Expected snapshot scope 'agent', got 'multiAgent'") + }) + + it('restores messages from snapshot', () => { + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: '1.0', + createdAt: createTimestamp(), + data: { + messages: [{ role: 'user', content: [{ text: 'Restored message' }] }], + }, + appData: {}, + } + + loadSnapshot(agent, snapshot) + + expect(agent.messages).toHaveLength(1) + expect(agent.messages[0]).toEqual(new Message({ role: 'user', content: [new TextBlock('Restored message')] })) + }) + + it('restores state from snapshot', () => { + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: '1.0', + createdAt: createTimestamp(), + data: { + state: { restoredKey: 'restoredValue' }, + }, + appData: {}, + } + + loadSnapshot(agent, snapshot) + + expect(agent.appState.get('restoredKey')).toBe('restoredValue') + }) + + it('restores systemPrompt from snapshot', () => { + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: '1.0', + createdAt: createTimestamp(), + data: { + systemPrompt: 'Restored system prompt', + }, + appData: {}, + } + + loadSnapshot(agent, snapshot) + + expect(agent.systemPrompt).toBe('Restored system prompt') + }) + + it('clears systemPrompt when snapshot has null systemPrompt (agent had no system prompt at snapshot time)', () => { + agent.systemPrompt = 'Original prompt' + + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: '1.0', + createdAt: createTimestamp(), + data: { systemPrompt: null }, + appData: {}, + } + + loadSnapshot(agent, snapshot) + + // null in snapshot means the agent had no system prompt — should be cleared + expect(agent.systemPrompt).toBeUndefined() + }) + + it('leaves systemPrompt unchanged when systemPrompt key is absent from snapshot', () => { + agent.systemPrompt = 'Original prompt' + + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: '1.0', + createdAt: createTimestamp(), + data: { messages: [] }, // systemPrompt key not present at all + appData: {}, + } + + loadSnapshot(agent, snapshot) + + // absent key means field was not snapshotted — agent prompt should be untouched + expect(agent.systemPrompt).toBe('Original prompt') + }) + + it('leaves messages unchanged when messages key is absent from snapshot', () => { + agent.messages.push(new Message({ role: 'user', content: [new TextBlock('Existing')] })) + + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: '1.0', + createdAt: createTimestamp(), + data: { state: { key: 'val' } }, // messages key not present + appData: {}, + } + + loadSnapshot(agent, snapshot) + + expect(agent.messages).toHaveLength(1) + }) + + it('leaves state unchanged when state key is absent from snapshot', () => { + agent.appState.set('existing', 'value') + + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: '1.0', + createdAt: createTimestamp(), + data: { messages: [] }, // state key not present + appData: {}, + } + + loadSnapshot(agent, snapshot) + + expect(agent.appState.get('existing')).toBe('value') + }) + }) + + describe('round-trip', () => { + let agent: Agent + + beforeEach(() => { + agent = createTestAgent() + }) + + it('preserves messages through save/load cycle', () => { + const originalMessages = [ + new Message({ role: 'user', content: [new TextBlock('Hello')] }), + new Message({ role: 'assistant', content: [new TextBlock('Hi!')] }), + ] + agent.messages.push(...originalMessages) + + const snapshot = takeSnapshot(agent, { preset: 'session' }) + + // Modify agent + agent.messages.length = 0 + agent.messages.push(new Message({ role: 'user', content: [new TextBlock('Different')] })) + + // Restore + loadSnapshot(agent, snapshot) + + expect(agent.messages).toEqual(originalMessages) + }) + + it('preserves state through save/load cycle', () => { + agent.appState.set('userId', 'user-123') + agent.appState.set('counter', 42) + + const snapshot = takeSnapshot(agent, { preset: 'session' }) + + // Modify state + agent.appState.clear() + agent.appState.set('different', 'value') + + // Restore + loadSnapshot(agent, snapshot) + + expect(agent.appState.getAll()).toEqual({ userId: 'user-123', counter: 42 }) + }) + + it('handles complex message content', () => { + const toolUseBlock = new ToolUseBlock({ + name: 'calculator', + toolUseId: 'tool-123', + input: { operation: 'add', numbers: [1, 2, 3] }, + }) + const toolResultBlock = new ToolResultBlock({ + toolUseId: 'tool-123', + status: 'success', + content: [new TextBlock('6')], + }) + const originalMessages = [ + new Message({ role: 'assistant', content: [toolUseBlock] }), + new Message({ role: 'user', content: [toolResultBlock] }), + ] + agent.messages.push(...originalMessages) + + const snapshot = takeSnapshot(agent, { include: ['messages'] }) + agent.messages.length = 0 + loadSnapshot(agent, snapshot) + + expect(agent.messages).toEqual(originalMessages) + }) + }) + + describe('JSON serialization', () => { + it('snapshot survives JSON.stringify/JSON.parse round-trip', () => { + const agent = createTestAgent() + agent.messages.push(new Message({ role: 'user', content: [new TextBlock('Hello')] })) + agent.appState.set('userId', 'user-123') + agent.systemPrompt = 'You are a helpful assistant' + + const snapshot = takeSnapshot(agent, { preset: 'session' }) + + // Serialize to JSON string and parse back + const jsonString = JSON.stringify(snapshot) + const parsed = JSON.parse(jsonString) + + // Verify structure is preserved + expect(parsed).toEqual(snapshot) + }) + + it('snapshot can be stored and retrieved as JSON string', () => { + const agent = createTestAgent() + agent.messages.push(new Message({ role: 'user', content: [new TextBlock('Test message')] })) + agent.appState.set('key', 'value') + + const snapshot = takeSnapshot(agent, { preset: 'session' }) + + // Simulate storing to a database or file as JSON + const stored = JSON.stringify(snapshot) + + // Simulate retrieving and restoring + const retrieved = JSON.parse(stored) + const newAgent = createTestAgent() + loadSnapshot(newAgent, retrieved) + + expect(newAgent.messages).toHaveLength(1) + expect(newAgent.appState.getAll()).toEqual({ key: 'value' }) + }) + }) + + describe('interrupt state round-trip', () => { + it('preserves interrupt state through snapshot and restores for resume', async () => { + // Set up agent that will interrupt + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-1', + input: { action: 'delete' }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('confirmTool', (context) => { + const response = context.interrupt({ name: 'confirm', reason: 'Confirm delete?' }) + return `confirmed: ${response}` + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + // Trigger interrupt + const interruptResult = await agent.invoke('Delete it') + expect(interruptResult.stopReason).toBe('interrupt') + expect(interruptResult.interrupts).toHaveLength(1) + + // Snapshot the interrupted agent + const snapshot = takeSnapshot(agent, { preset: 'session' }) + expect(snapshot.data.interrupts).toBeDefined() + + // Create a fresh agent and restore from snapshot + const model2 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Done' }) + const tool2 = createMockTool('confirmTool', (context) => { + const response = context.interrupt({ name: 'confirm', reason: 'Confirm delete?' }) + return `confirmed: ${response}` + }) + const restoredAgent = new Agent({ model: model2, tools: [tool2], printer: false }) + loadSnapshot(restoredAgent, snapshot) + + // Resume from the restored agent + const finalResult = await restoredAgent.invoke([ + { + interruptResponse: { + interruptId: interruptResult.interrupts![0]!.id, + response: 'yes', + }, + }, + ]) + + expect(finalResult.stopReason).toBe('endTurn') + }) + }) +}) + +describe('Agent.takeSnapshot / Agent.loadSnapshot (public API)', () => { + beforeEach(() => { + vi.useFakeTimers() + vi.setSystemTime(new Date(MOCK_TIMESTAMP)) + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('takeSnapshot captures state and loadSnapshot restores it (round-trip)', () => { + const agent = new Agent({ model: new TestModelProvider(), tools: [], printer: false }) + agent.messages.push( + new Message({ role: 'user', content: [new TextBlock('Hello')] }), + new Message({ role: 'assistant', content: [new TextBlock('Hi!')] }) + ) + agent.appState.set('counter', 42) + agent.systemPrompt = 'Be helpful' + + const snapshot = agent.takeSnapshot({ preset: 'session' }) + + expect(snapshot).toEqual({ + scope: 'agent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { + messages: [ + { role: 'user', content: [{ text: 'Hello' }] }, + { role: 'assistant', content: [{ text: 'Hi!' }] }, + ], + state: { counter: 42 }, + systemPrompt: 'Be helpful', + modelState: {}, + interrupts: { interrupts: {}, activated: false }, + }, + appData: {}, + }) + + // Mutate agent state + agent.messages.length = 0 + agent.appState.clear() + agent.systemPrompt = 'Different' + + // Restore + agent.loadSnapshot(snapshot) + + expect(agent.messages).toHaveLength(2) + expect(agent.appState.get('counter')).toBe(42) + expect(agent.systemPrompt).toBe('Be helpful') + }) + + it('propagates errors from loadSnapshot for invalid snapshots', () => { + const agent = new Agent({ model: new TestModelProvider(), tools: [], printer: false }) + + expect(() => + agent.loadSnapshot({ scope: 'agent', schemaVersion: '99.0', createdAt: '', data: {}, appData: {} }) + ).toThrow('Unsupported snapshot schema version: 99.0') + + expect(() => + agent.loadSnapshot({ + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: '', + data: {}, + appData: {}, + }) + ).toThrow("Expected snapshot scope 'agent', got 'multiAgent'") + + expect(() => agent.takeSnapshot({})).toThrow('No fields to include in snapshot') + }) + + it('supports JSON serialization round-trip', () => { + const agent = new Agent({ model: new TestModelProvider(), tools: [], printer: false }) + agent.messages.push(new Message({ role: 'user', content: [new TextBlock('Persist me')] })) + agent.appState.set('session', 'abc') + + const snapshot = agent.takeSnapshot({ preset: 'session' }) + const json = JSON.stringify(snapshot) + const parsed = JSON.parse(json) as Snapshot + + const newAgent = new Agent({ model: new TestModelProvider(), tools: [], printer: false }) + newAgent.loadSnapshot(parsed) + + expect(newAgent.messages).toHaveLength(1) + expect(newAgent.appState.get('session')).toBe('abc') + }) + + it('preserves and restores interrupt state for resume', async () => { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'askUser', + toolUseId: 'tool-1', + input: { question: 'proceed?' }, + }) + .addTurn({ type: 'textBlock', text: 'Completed' }) + + const tool = createMockTool('askUser', (context) => { + const answer = context.interrupt({ name: 'ask', reason: 'Need confirmation' }) + return `User said: ${answer}` + }) + + const agent = new Agent({ model, tools: [tool], printer: false }) + + // Trigger interrupt + const result = await agent.invoke('Do something') + expect(result.stopReason).toBe('interrupt') + + // Snapshot via public method + const snapshot = agent.takeSnapshot({ preset: 'session' }) + + // Restore into a fresh agent + const model2 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Completed' }) + const tool2 = createMockTool('askUser', (context) => { + const answer = context.interrupt({ name: 'ask', reason: 'Need confirmation' }) + return `User said: ${answer}` + }) + const restored = new Agent({ model: model2, tools: [tool2], printer: false }) + restored.loadSnapshot(snapshot) + + // Resume + const finalResult = await restored.invoke([ + { interruptResponse: { interruptId: result.interrupts![0]!.id, response: 'go ahead' } }, + ]) + + expect(finalResult.stopReason).toBe('endTurn') + }) +}) diff --git a/strands-ts/src/agent/__tests__/tool-caller.test.ts b/strands-ts/src/agent/__tests__/tool-caller.test.ts new file mode 100644 index 0000000000..e98b08ac99 --- /dev/null +++ b/strands-ts/src/agent/__tests__/tool-caller.test.ts @@ -0,0 +1,626 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { Message, ToolResultBlock, TextBlock, ToolUseBlock } from '../../types/messages.js' +import { ConcurrentInvocationError, ToolNotFoundError } from '../../errors.js' +import { ToolStreamEvent } from '../../tools/tool.js' +import type { ToolContext } from '../../tools/tool.js' + +describe('ToolCaller', () => { + describe('basic tool calling via .invoke()', () => { + it('calls a tool by name and returns the result', async () => { + const tool = createMockTool( + 'calculator', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('8')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + const result = await agent.tool.calculator!.invoke({ a: 5, b: 3 }) + + expect(result).toStrictEqual( + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('8')], + }) + ) + }) + + it('calls a tool with empty input when no input provided', async () => { + const tool = createMockTool( + 'ping', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('pong')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + const result = await agent.tool.ping!.invoke() + + expect(result).toStrictEqual( + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('pong')], + }) + ) + }) + + it('throws when tool is not found', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [] }) + + await expect(agent.tool.nonexistent!.invoke()).rejects.toThrow(ToolNotFoundError) + await expect(agent.tool.nonexistent!.invoke()).rejects.toThrow("Tool 'nonexistent' not found") + }) + }) + + describe('underscore-to-hyphen normalization', () => { + it('resolves underscore names to hyphenated tool names', async () => { + const tool = createMockTool( + 'my-tool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('ok')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + const result = await agent.tool.my_tool!.invoke() + + expect(result.status).toBe('success') + }) + + it('prefers exact name match over normalized match', async () => { + const exactTool = createMockTool( + 'my_tool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('exact')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [exactTool] }) + + const result = await agent.tool.my_tool!.invoke() + + expect(result).toStrictEqual( + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('exact')], + }) + ) + }) + }) + + describe('case-insensitive name resolution', () => { + it('resolves tool names case-insensitively', async () => { + const tool = createMockTool( + 'MyTool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('ok')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + const result = await agent.tool.mytool!.invoke() + + expect(result.status).toBe('success') + }) + + it('prefers exact match over case-insensitive match', async () => { + const exactTool = createMockTool( + 'myTool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('exact')], + }) + ) + const upperTool = createMockTool( + 'MYTOOL', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('upper')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [exactTool, upperTool] }) + + const result = await agent.tool.myTool!.invoke() + + expect(result.content[0]).toStrictEqual(new TextBlock('exact')) + }) + }) + + describe('message history recording', () => { + it('records tool call in message history by default', async () => { + const tool = createMockTool( + 'calculator', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('8')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + await agent.tool.calculator!.invoke({ a: 5, b: 3 }) + + // Per TESTING.md, prefer full-object assertions over per-field checks. + // toolUseId is non-deterministic (UUID), so use expect.stringMatching. + expect(agent.messages).toEqual([ + new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ + toolUseId: expect.stringMatching(/^tooluse_/) as unknown as string, + name: 'calculator', + input: { a: 5, b: 3 }, + }), + ], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('8')], + }), + ], + }), + new Message({ + role: 'assistant', + content: [new TextBlock('agent.tool.calculator was called.')], + }), + ]) + }) + + it('does not record when recordDirectToolCall is false per-call', async () => { + const tool = createMockTool( + 'calculator', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('8')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + await agent.tool.calculator!.invoke({ a: 5, b: 3 }, { recordDirectToolCall: false }) + + expect(agent.messages).toHaveLength(0) + }) + + it('records when explicitly set to true per-call', async () => { + const tool = createMockTool( + 'calculator', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('8')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + await agent.tool.calculator!.invoke({ a: 5, b: 3 }, { recordDirectToolCall: true }) + + expect(agent.messages).toHaveLength(3) + }) + + it('records full input without filtering', async () => { + const tool = createMockTool( + 'my-tool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('ok')], + }) + ) + Object.defineProperty(tool, 'toolSpec', { + value: { + name: 'my-tool', + description: 'Tool with strict schema', + inputSchema: { + type: 'object', + properties: { + allowed: { type: 'string' }, + }, + }, + }, + }) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + await agent.tool.my_tool!.invoke({ allowed: 'yes', extra: 'also-recorded' }) + + // Input is recorded as-is — no filtering + const recToolUseBlock = agent.messages[0]!.content[0] as ToolUseBlock + expect(recToolUseBlock).toBeInstanceOf(ToolUseBlock) + expect(recToolUseBlock.input).toStrictEqual({ allowed: 'yes', extra: 'also-recorded' }) + }) + }) + + describe('concurrency protection', () => { + it('throws ConcurrentInvocationError when agent is invoking and recording is enabled', async () => { + const tool = createMockTool( + 'slow-tool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('done')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + // Simulate the agent being in the middle of an invocation by mocking isInvoking + Object.defineProperty(agent, 'isInvoking', { get: () => true }) + + await expect(agent.tool.slow_tool!.invoke()).rejects.toThrow(ConcurrentInvocationError) + await expect(agent.tool.slow_tool!.invoke()).rejects.toThrow( + 'Direct tool call cannot be made while the agent is in the middle of an invocation' + ) + }) + + it('allows direct tool call during invocation when recordDirectToolCall is false', async () => { + const tool = createMockTool( + 'quick-tool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('ok')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + // Simulate the agent being in the middle of an invocation + Object.defineProperty(agent, 'isInvoking', { get: () => true }) + + // Should NOT throw when recording is disabled + const result = await agent.tool.quick_tool!.invoke({}, { recordDirectToolCall: false }) + expect(result.status).toBe('success') + }) + + it('isInvoking is false on a fresh agent', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + expect(agent.isInvoking).toBe(false) + }) + }) + + describe('tool error handling', () => { + it('propagates errors when tool throws', async () => { + const throwingTool = createMockTool('thrower', () => { + throw new Error('Boom!') + }) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [throwingTool] }) + + await expect(agent.tool.thrower!.invoke()).rejects.toThrow('Boom!') + }) + }) + + describe('agent.tool accessor', () => { + it('is accessible as a property', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + expect(agent.tool).toBeDefined() + }) + + it('returns same instance on multiple accesses', () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model }) + + expect(agent.tool).toBe(agent.tool) + }) + + it('returns a ToolHandle with invoke and stream methods', () => { + const tool = createMockTool( + 'calculator', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('ok')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + const handle = agent.tool.calculator! + expect(typeof handle.invoke).toBe('function') + expect(typeof handle.stream).toBe('function') + }) + }) + + describe('tool use ID generation', () => { + it('generates unique tool use IDs using crypto.randomUUID', async () => { + const tool = createMockTool( + 'id-tool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('ok')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + await agent.tool.id_tool!.invoke() + await agent.tool.id_tool!.invoke() + + // Each call records 3 messages: [0]=assistant(toolUse), [1]=user(toolResult), [2]=assistant(ack) + // Second call: [3]=assistant(toolUse), [4]=user(toolResult), [5]=assistant(ack) + expect(agent.messages).toHaveLength(6) + + const toolUse1 = agent.messages[0]!.content[0] as ToolUseBlock + const toolUse2 = agent.messages[3]!.content[0] as ToolUseBlock + + // Verify both are ToolUseBlocks at the correct indices + expect(toolUse1).toBeInstanceOf(ToolUseBlock) + expect(toolUse2).toBeInstanceOf(ToolUseBlock) + + // Verify IDs are unique and follow the expected format + expect(toolUse1.toolUseId).toMatch(/^tooluse_/) + expect(toolUse2.toolUseId).toMatch(/^tooluse_/) + expect(toolUse1.toolUseId).not.toBe(toolUse2.toolUseId) + }) + }) + + describe('streaming via .stream()', () => { + it('yields intermediate events and returns final result', async () => { + const yields: string[] = [] + const streamingTool = { + name: 'streamer', + description: 'A tool that yields progress events', + toolSpec: { + name: 'streamer', + description: 'A tool that yields progress events', + inputSchema: { type: 'object' as const, properties: {} }, + }, + async *stream(): AsyncGenerator { + yields.push('first') + yield new ToolStreamEvent({ data: 'step 1' }) + yields.push('second') + yield new ToolStreamEvent({ data: 'step 2' }) + yields.push('third') + yield new ToolStreamEvent({ data: 'step 3' }) + return new ToolResultBlock({ + toolUseId: 'stream-id', + status: 'success', + content: [new TextBlock('complete')], + }) + }, + } + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [streamingTool] }) + + const events: ToolStreamEvent[] = [] + const gen = agent.tool.streamer!.stream() + let result = await gen.next() + while (!result.done) { + events.push(result.value) + result = await gen.next() + } + const finalResult = result.value + + expect(finalResult.status).toBe('success') + expect(finalResult.content[0]).toStrictEqual(new TextBlock('complete')) + // Verify all yields were consumed (generator fully iterated) + expect(yields).toStrictEqual(['first', 'second', 'third']) + // Verify we received all 3 stream events + expect(events).toHaveLength(3) + }) + + it('invoke() also fully consumes multi-yield generator', async () => { + const yields: string[] = [] + const streamingTool = { + name: 'streamer', + description: 'A tool that yields progress events', + toolSpec: { + name: 'streamer', + description: 'A tool that yields progress events', + inputSchema: { type: 'object' as const, properties: {} }, + }, + async *stream(): AsyncGenerator { + yields.push('first') + yield new ToolStreamEvent({ data: 'step 1' }) + yields.push('second') + yield new ToolStreamEvent({ data: 'step 2' }) + yields.push('third') + yield new ToolStreamEvent({ data: 'step 3' }) + return new ToolResultBlock({ + toolUseId: 'stream-id', + status: 'success', + content: [new TextBlock('complete')], + }) + }, + } + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [streamingTool] }) + + const result = await agent.tool.streamer!.invoke() + + expect(result.status).toBe('success') + expect(result.content[0]).toStrictEqual(new TextBlock('complete')) + // Verify all yields were consumed even when using .invoke() + expect(yields).toStrictEqual(['first', 'second', 'third']) + }) + }) + + describe('tool input passthrough', () => { + it('passes ALL parameters to tool execution', async () => { + let receivedInput: unknown = null + const tool = createMockTool( + 'capture-tool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('captured')], + }) + ) + // Override stream to capture input + const originalStream = tool.stream.bind(tool) + tool.stream = function (context: ToolContext) { + receivedInput = context.toolUse.input + return originalStream(context) + } + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + await agent.tool.capture_tool!.invoke({ allowed: 'yes', extra: 'should-pass-through' }) + + // Tool receives ALL parameters + expect(receivedInput).toStrictEqual({ allowed: 'yes', extra: 'should-pass-through' }) + }) + }) + + describe('dynamically added tools', () => { + it('can call a tool that was added after agent creation', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [] }) + + // Add tool after creation + const laterTool = createMockTool( + 'later-tool', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('dynamic')], + }) + ) + agent.toolRegistry.add(laterTool) + + const result = await agent.tool.later_tool!.invoke() + + expect(result.status).toBe('success') + expect(result.content[0]).toStrictEqual(new TextBlock('dynamic')) + }) + }) +}) + +describe('MessageAddedEvent hooks', () => { + it('fires MessageAddedEvent for each message recorded during direct tool call', async () => { + const { MessageAddedEvent } = await import('../../hooks/events.js') + + const tool = createMockTool( + 'calculator', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('8')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + const firedEvents: InstanceType[] = [] + agent.addHook(MessageAddedEvent, (event) => { + firedEvents.push(event) + }) + + await agent.tool.calculator!.invoke({ a: 5, b: 3 }) + + // Should fire 3 MessageAddedEvents (one per recorded message). + // Use full-object assertions per TESTING.md. + expect(firedEvents).toHaveLength(3) + expect(firedEvents[0]!.message).toEqual( + new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ + toolUseId: expect.stringMatching(/^tooluse_/) as unknown as string, + name: 'calculator', + input: { a: 5, b: 3 }, + }), + ], + }) + ) + expect(firedEvents[1]!.message).toEqual( + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('8')], + }), + ], + }) + ) + expect(firedEvents[2]!.message).toEqual( + new Message({ + role: 'assistant', + content: [new TextBlock('agent.tool.calculator was called.')], + }) + ) + }) + + it('does not fire MessageAddedEvent when recordDirectToolCall is false', async () => { + const { MessageAddedEvent } = await import('../../hooks/events.js') + + const tool = createMockTool( + 'calculator', + () => + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('8')], + }) + ) + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, tools: [tool] }) + + const firedEvents: InstanceType[] = [] + agent.addHook(MessageAddedEvent, (event) => { + firedEvents.push(event) + }) + + await agent.tool.calculator!.invoke({ a: 5, b: 3 }, { recordDirectToolCall: false }) + + // No events should fire when recording is disabled + expect(firedEvents).toHaveLength(0) + }) +}) diff --git a/strands-ts/src/agent/agent-as-tool.ts b/strands-ts/src/agent/agent-as-tool.ts new file mode 100644 index 0000000000..c0bcb107b3 --- /dev/null +++ b/strands-ts/src/agent/agent-as-tool.ts @@ -0,0 +1,201 @@ +/** + * Agent-as-tool adapter. + * + * This module provides the AgentAsTool class that wraps an Agent as a tool + * so it can be used by another agent. Agents passed directly in the tools + * array are automatically wrapped via {@link Agent.asTool}. + */ + +import type { Agent } from './agent.js' +import type { Snapshot } from '../types/snapshot.js' +import type { JSONValue } from '../types/json.js' +import { JsonBlock, TextBlock, ToolResultBlock } from '../types/messages.js' +import { createErrorResult, Tool, ToolStreamEvent } from '../tools/tool.js' +import type { ToolContext, ToolStreamGenerator } from '../tools/tool.js' +import type { ToolSpec } from '../tools/types.js' + +/** + * Options for creating an agent tool via {@link Agent.asTool}. + */ +export interface AgentAsToolOptions { + /** + * Tool name exposed to the parent agent's model. + * Must match the pattern `[a-zA-Z0-9_-]{1,64}`. + * + * Defaults to the agent's name. Throws if the resolved name is not a valid + * tool name — provide an explicit name option to override. + */ + name?: string + + /** + * Tool description exposed to the parent agent's model. + * Helps the model understand when to use this tool. + * + * Defaults to the agent's description, or a generic description if the + * agent has no description set. + */ + description?: string + + /** + * Whether to preserve the agent's conversation history across invocations. + * + * When `false` (default), the agent's messages and state are reset to the + * values they had at the time the tool was created, ensuring every call + * starts from the same baseline. + * + * When `true`, the agent retains its conversation history across invocations, + * allowing it to build context over multiple calls. + * + * @defaultValue false + */ + preserveContext?: boolean +} + +/** + * Configuration for creating an AgentAsTool. + */ +interface AgentToolConfig extends AgentAsToolOptions { + agent: Agent +} + +/** + * @internal Not for external use. Use {@link Agent.asTool} to create instances. + * + * Adapter that exposes an Agent as a tool for use by other agents. + * + * The tool accepts a single `input` string parameter, invokes the wrapped + * agent, and returns the text response. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * + * const researcher = new Agent({ + * name: 'researcher', + * description: 'Finds information on a topic', + * printer: false, + * }) + * + * // Use via convenience method (default: fresh conversation each call) + * const tool = researcher.asTool() + * + * // Preserve context across invocations + * const tool = researcher.asTool({ preserveContext: true }) + * + * const writer = new Agent({ tools: [tool] }) + * const result = await writer.invoke('Write about AI agents') + * ``` + */ +export class AgentAsTool extends Tool { + readonly name: string + readonly description: string + readonly toolSpec: ToolSpec + + private readonly _agent: Agent + private readonly _preserveContext: boolean + private readonly _initialSnapshot: Snapshot | undefined + private _busy = false + + constructor(config: AgentToolConfig) { + super() + this._agent = config.agent + this._preserveContext = config.preserveContext ?? false + + if (!this._preserveContext && this._agent.sessionManager != null) { + throw new Error( + `Agent '${this._agent.name}' has a SessionManager, which conflicts with preserveContext=false. ` + + 'The SessionManager persists conversation history externally, but preserveContext=false resets ' + + 'state between invocations. Use preserveContext=true or remove the SessionManager.' + ) + } + + if (!this._preserveContext) { + this._initialSnapshot = this._agent.takeSnapshot({ preset: 'session' }) + } + + this.name = config.name ?? config.agent.name + + this.description = + config.description ?? + config.agent.description ?? + `Use the ${this.name} agent by providing a natural language input` + + this.toolSpec = { + name: this.name, + description: this.description, + inputSchema: { + type: 'object', + properties: { + input: { + type: 'string', + description: 'The natural language input to send to the agent.', + }, + }, + required: ['input'], + }, + } + } + + /** + * The wrapped agent instance. + */ + get agent(): Agent { + return this._agent + } + + async *stream(toolContext: ToolContext): ToolStreamGenerator { + const { toolUse } = toolContext + const toolUseId = toolUse.toolUseId + + // Concurrency guard: loadSnapshot + agent.stream() must not overlap. + if (this._busy) { + return createErrorResult(`Agent '${this.name}' is already processing a request`, toolUseId) + } + + this._busy = true + try { + const { input } = toolUse.input as { input: string } + + // Reset agent state if not preserving context + if (this._initialSnapshot) { + this._agent.loadSnapshot(this._initialSnapshot) + } + + // Stream the sub-agent, forwarding the outer invocation's state so + // mutations in the inner agent's hooks/tools are visible to the outer + // agent's downstream callbacks and final AgentResult. + const gen = this._agent.stream(input, { invocationState: toolContext.invocationState }) + let next = await gen.next() + while (!next.done) { + const event = next.value + if (event.type == 'toolStreamUpdateEvent') { + yield event.event + } else { + yield new ToolStreamEvent({ data: next.value }) + } + + next = await gen.next() + } + const result = next.value + + // Build the tool result + if (result.structuredOutput !== undefined) { + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: [new JsonBlock({ json: result.structuredOutput as JSONValue })], + }) + } + + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: [new TextBlock(result.toString())], // toString defined by AgentResult + }) + } catch (error) { + return createErrorResult(error, toolUseId) + } finally { + this._busy = false + } + } +} diff --git a/strands-ts/src/agent/agent.ts b/strands-ts/src/agent/agent.ts new file mode 100644 index 0000000000..41c674ed7a --- /dev/null +++ b/strands-ts/src/agent/agent.ts @@ -0,0 +1,2191 @@ +import { + AgentResult, + type AgentStreamEvent, + type InvocationState, + type InvokableAgent, + type InvokeArgs, + type InvokeOptions, + type LocalAgent, + type localAgentSymbol, +} from '../types/agent.js' +import { BedrockModel } from '../models/bedrock.js' +import { + contentBlockFromData, + type ContentBlock, + type ContentBlockData, + Message, + type MessageData, + type SystemPrompt, + type SystemPromptData, + TextBlock, + ToolResultBlock, + type ToolResultBlockData, + ToolUseBlock, +} from '../types/messages.js' +import type { JSONValue } from '../types/json.js' +import { McpClient } from '../mcp.js' +import { isValidToolName, type Tool, type ToolContext } from '../tools/tool.js' +import type { ToolChoice } from '../tools/types.js' +import { systemPromptFromData } from '../types/messages.js' +import { normalizeError, ConcurrentInvocationError, StructuredOutputError } from '../errors.js' +import { Model } from '../models/model.js' +import type { BaseModelConfig, StreamAggregatedResult, StreamOptions } from '../models/model.js' +import { ModelPlugin } from '../plugins/model-plugin.js' +import { isModelStreamEvent } from '../models/streaming.js' +import { ToolRegistry } from '../registry/tool-registry.js' +import { StateStore } from '../state-store.js' +import { AgentPrinter, getDefaultAppender, type Printer } from './printer.js' +import type { Plugin } from '../plugins/plugin.js' +import type { InterventionHandler } from '../interventions/handler.js' +import { InterventionRegistry } from '../interventions/registry.js' +import type { LifecycleObserver } from '../types/lifecycle-observer.js' +import { PluginRegistry } from '../plugins/registry.js' +import { SlidingWindowConversationManager } from '../conversation-manager/sliding-window-conversation-manager.js' +import { NullConversationManager } from '../conversation-manager/null-conversation-manager.js' +import { ConversationManager } from '../conversation-manager/conversation-manager.js' +import { HookRegistryImplementation } from '../hooks/registry.js' +import type { HookableEventConstructor, HookCallback, HookCallbackOptions, HookCleanup } from '../hooks/types.js' +import { + InitializedEvent, + AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, + AfterToolsEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + BeforeToolsEvent, + HookableEvent, + MessageAddedEvent, + ModelStreamUpdateEvent, + ContentBlockEvent, + ModelMessageEvent, + ToolResultEvent, + AgentResultEvent, + ToolStreamUpdateEvent, + InterruptEvent, + type ModelStopData, +} from '../hooks/events.js' +import { StructuredOutputTool, STRUCTURED_OUTPUT_TOOL_NAME } from '../tools/structured-output-tool.js' +import { AgentAsTool } from './agent-as-tool.js' +import type { AgentAsToolOptions } from './agent-as-tool.js' +import { ToolCaller } from './tool-caller.js' +import type { ToolCallerProxy } from './tool-caller.js' + +import type { z } from 'zod' +import { SessionManager } from '../session/session-manager.js' +import { Tracer } from '../telemetry/tracer.js' +import { Meter } from '../telemetry/meter.js' +import type { AttributeValue } from '@opentelemetry/api' +import { logger } from '../logging/logger.js' +import { CancelledError } from '../errors.js' +import { DefaultModelRetryStrategy } from '../retry/default-model-retry-strategy.js' +import type { RetryStrategy } from '../retry/retry-strategy.js' +import { warnOnDuplicateRetryStrategyTypes } from '../retry/retry-strategy.js' +import { InterruptError, InterruptState, interruptFromAgent } from '../interrupt.js' +import type { InterruptParams } from '../types/interrupt.js' +import { isInterruptResponseContent, type InterruptResponseContent } from '../types/interrupt.js' +import { takeSnapshot as takeSnapshotInternal, loadSnapshot as loadSnapshotInternal } from './snapshot.js' +import type { TakeSnapshotOptions } from './snapshot.js' +import type { Snapshot } from '../types/snapshot.js' + +/** + * Recursive type definition for nested tool arrays. + * Allows tools to be organized in nested arrays of any depth. + * + * {@link Agent} instances in the array are automatically wrapped via + * {@link Agent.asTool}, so they can be passed directly without calling + * `.asTool()` explicitly. + */ +export type ToolList = (Tool | McpClient | Agent | ToolList)[] + +/** + * Strategy for executing tool calls that the model emits in a single assistant turn. + * + * - `'concurrent'` (default) — runs all tool calls from a single turn in parallel. Per-tool event + * order (`BeforeToolCallEvent` → `ToolStreamUpdateEvent*` → `AfterToolCallEvent` → + * `ToolResultEvent`) is preserved, while cross-tool events may interleave. + * - `'sequential'` — runs tool calls one at a time + * + * Cancellation works identically in both modes: {@link Agent.cancel} flips + * {@link Agent.cancelSignal} and tools must observe it cooperatively to stop early. + * In concurrent mode, prompt batch-wide cancellation requires every in-flight tool + * to honor the signal. + */ +export type ToolExecutorStrategy = 'sequential' | 'concurrent' + +/** + * Configuration object for creating a new Agent. + */ +export type AgentConfig = { + /** + * The model instance that the agent will use to make decisions. + * Accepts either a Model instance or a string representing a Bedrock model ID. + * When a string is provided, it will be used to create a BedrockModel instance. + * + * @example + * ```typescript + * // Using a string model ID (creates BedrockModel) + * const agent = new Agent({ + * model: 'anthropic.claude-3-5-sonnet-20240620-v1:0' + * }) + * + * // Using an explicit BedrockModel instance with configuration + * const agent = new Agent({ + * model: new BedrockModel({ + * modelId: 'anthropic.claude-3-5-sonnet-20240620-v1:0', + * temperature: 0.7, + * maxTokens: 2048 + * }) + * }) + * ``` + */ + model?: Model | string + /** An initial set of messages to seed the agent's conversation history. */ + messages?: Message[] | MessageData[] + /** + * An initial set of tools to register with the agent. + * Accepts nested arrays of tools at any depth, which will be flattened automatically. + * {@link Agent} instances are automatically wrapped as tools via {@link Agent.asTool}. + */ + tools?: ToolList + /** + * A system prompt which guides model behavior. + */ + systemPrompt?: SystemPrompt | SystemPromptData + /** Optional initial state values for the agent. */ + appState?: Record + /** + * Optional initial model-provider state (e.g., restoring `responseId` from a + * prior session). Typically only set when hydrating from a snapshot. + */ + modelState?: Record + /** + * Enable automatic printing of agent output to console. + * When true, prints text generation, reasoning, and tool usage as they occur. + * Defaults to true. + */ + printer?: boolean + /** + * Conversation manager for handling message history and context overflow. + * Defaults to SlidingWindowConversationManager with windowSize of 40. + */ + conversationManager?: ConversationManager + /** + * Plugins to register with the agent. + */ + plugins?: Plugin[] + /** + * Retry strategy (or strategies) for failed model/tool calls. + * + * - Omitted: a sensible default {@link DefaultModelRetryStrategy} with exponential backoff is used. + * - Single strategy: the given strategy is used. + * - Array of strategies: all are registered, in the given order. Passing two + * instances of the same concrete class logs a warning — they will collide + * on `plugin.name` when the plugin registry initializes. + * - `null` or `[]`: retries are explicitly disabled; failures propagate to the caller. + */ + retryStrategy?: RetryStrategy | RetryStrategy[] | null + /** + * Intervention handlers evaluated in registration order at each lifecycle point. + */ + interventions?: InterventionHandler[] + /** + * Zod schema for structured output validation. + */ + structuredOutputSchema?: z.ZodSchema + /** + * Session manager for saving and restoring agent sessions + */ + sessionManager?: SessionManager + /** + * Custom trace attributes to include in all spans. + * These attributes are merged with standard attributes in telemetry spans. + * Telemetry must be enabled globally via telemetry.setupTracer() for these to take effect. + */ + traceAttributes?: Record + /** + * Optional name for the agent. Defaults to "Strands Agent". + */ + name?: string + /** + * Optional description of what the agent does. + */ + description?: string + /** + * Optional unique identifier for the agent. Defaults to "agent". + */ + id?: string + /** + * Strategy for executing tool calls from a single assistant turn. + * Defaults to `'concurrent'`. See {@link ToolExecutorStrategy} for details. + */ + toolExecutor?: ToolExecutorStrategy +} + +/** Default name assigned to agents when none is provided. */ +const DEFAULT_AGENT_NAME = 'Strands Agent' + +/** Default identifier assigned to agents when none is provided. */ +const DEFAULT_AGENT_ID = 'agent' + +/** Result returned by tool-execution generators, threading the AfterToolsEvent back to the main loop. */ +type ToolsExecutionResult = { message: Message; afterToolsEvent: AfterToolsEvent } + +/** + * Orchestrates the interaction between a model, a set of tools, and MCP clients. + * The Agent is responsible for managing the lifecycle of tools and clients + * and invoking the core decision-making loop. + */ +export class Agent implements LocalAgent, InvokableAgent { + /** @internal */ + declare readonly [localAgentSymbol]: true + + /** + * The conversation history of messages between user and assistant. + */ + public messages: Message[] + /** + * App state storage accessible to tools and application logic. + * State is not passed to the model during inference. + */ + public readonly appState: StateStore + /** + * Runtime state for the model provider. Used by stateful models to persist + * provider-specific data (e.g., response IDs for conversation chaining) + * across invocations. + */ + public readonly modelState: StateStore + private readonly _conversationManager: ConversationManager + + /** + * The model provider used by the agent for inference. + */ + public model: Model + + /** + * The system prompt to pass to the model provider. + */ + public systemPrompt?: SystemPrompt + + /** + * The name of the agent. + */ + public readonly name: string + + /** + * The unique identifier of the agent instance. + */ + public readonly id: string + + /** + * Optional description of what the agent does. + */ + public readonly description?: string + + /** + * The session manager for saving and restoring agent sessions, if configured. + */ + public readonly sessionManager?: SessionManager | undefined + + private readonly _hooksRegistry: HookRegistryImplementation + private readonly _pluginRegistry: PluginRegistry + private readonly _interventionRegistry: InterventionRegistry + private _toolRegistry: ToolRegistry + private _mcpClients: McpClient[] + private _initialized: boolean + private _isInvoking: boolean = false + private _abortController = new AbortController() + private _abortSignal: AbortSignal = this._abortController.signal + private _printer?: Printer + private _structuredOutputSchema?: z.ZodSchema | undefined + /** Tracer instance for creating and managing OpenTelemetry spans. */ + private _tracer: Tracer + /** Meter instance for accumulating loop metrics during invocation. */ + private _meter: Meter + /** Interrupt state for human-in-the-loop workflows. */ + _interruptState: InterruptState + /** Strategy for executing tool calls from a single assistant turn. */ + private readonly _toolExecutor: ToolExecutorStrategy + /** Direct tool caller — created via {@link ToolCaller.create} factory. */ + private readonly _toolCaller: ToolCallerProxy + + /** + * Creates an instance of the Agent. + * @param config - The configuration for the agent. + */ + constructor(config?: AgentConfig) { + // Initialize public fields + this.messages = (config?.messages ?? []).map((msg) => (msg instanceof Message ? msg : Message.fromMessageData(msg))) + this.appState = new StateStore(config?.appState) + this.modelState = new StateStore(config?.modelState) + this.name = config?.name ?? DEFAULT_AGENT_NAME + this.id = config?.id ?? DEFAULT_AGENT_ID + if (config?.description !== undefined) this.description = config.description + this.sessionManager = config?.sessionManager + + if (typeof config?.model === 'string') { + this.model = new BedrockModel({ modelId: config.model }) + } else { + this.model = config?.model ?? new BedrockModel() + } + + // Validate and assign conversation manager + if (this.model.stateful) { + if (config?.conversationManager) { + throw new Error( + 'Cannot use a conversationManager with a stateful model. The model manages conversation state server-side.' + ) + } + this._conversationManager = new NullConversationManager() + } else { + this._conversationManager = + config?.conversationManager ?? new SlidingWindowConversationManager({ windowSize: 40 }) + } + + const { tools, mcpClients } = flattenTools(config?.tools ?? []) + this._toolRegistry = new ToolRegistry(tools) + this._mcpClients = mcpClients + + // Initialize hooks registry + this._hooksRegistry = new HookRegistryImplementation() + + this._interventionRegistry = new InterventionRegistry(config?.interventions ?? [], this._hooksRegistry) + + // `undefined` (omitted) → install the default; `null`/`[]` → explicit opt-out. + const retryStrategies: RetryStrategy[] = + config?.retryStrategy === null + ? [] + : config?.retryStrategy === undefined + ? [new DefaultModelRetryStrategy()] + : Array.isArray(config.retryStrategy) + ? config.retryStrategy + : [config.retryStrategy] + warnOnDuplicateRetryStrategyTypes(retryStrategies) + + // Initialize plugin registry with all plugins to be initialized during initialize(). + // Ordering notes: + // - ModelPlugin is registered last so that on AfterInvocationEvent (which uses + // reverse callback ordering), it runs first — clearing messages before + // SessionManager saves. + // - Retry-strategy ordering is not load-bearing for correctness: `DefaultModelRetryStrategy` + // guards on `event.retry`, so a user hook that already set it short-circuits + // the strategy regardless of registration order. + this._pluginRegistry = new PluginRegistry([ + this._conversationManager, + ...retryStrategies, + ...(config?.plugins ?? []), + ...(config?.sessionManager ? [config.sessionManager] : []), + new ModelPlugin(this.model), + ]) + + if (config?.systemPrompt !== undefined) { + this.systemPrompt = systemPromptFromData(config.systemPrompt) + } + + // Create printer if printer is enabled (default: true) + const printer = config?.printer ?? true + if (printer) { + this._printer = new AgentPrinter(getDefaultAppender()) + } + + // Store structured output schema + this._structuredOutputSchema = config?.structuredOutputSchema + + // Initialize tracer - OTEL returns no-op tracer if not configured + this._tracer = new Tracer(config?.traceAttributes) + + // Initialize meter for local metrics accumulation + this._meter = new Meter() + + // Initialize interrupt state for human-in-the-loop workflows + this._interruptState = new InterruptState() + + this._toolExecutor = config?.toolExecutor ?? 'concurrent' + // Pass a private helper into ToolCaller so message append + hook firing + // remains an internal concern of Agent (not exposed as a public method). + this._toolCaller = ToolCaller.create(this, (message, invocationState) => + this._appendMessageAndFireHooks(message, invocationState) + ) + + this._initialized = false + } + + /** + * Register a hook callback for a specific event type. + * + * @param eventType - The event class constructor to register the callback for + * @param callback - The callback function to invoke when the event occurs + * @param options - Optional configuration including execution order + * @returns Cleanup function that removes the callback when invoked + * + * @example + * ```typescript + * const agent = new Agent({ model }) + * + * const cleanup = agent.addHook(BeforeInvocationEvent, (event) => { + * console.log('Invocation started') + * }) + * + * // Later, to remove the hook: + * cleanup() + * ``` + */ + addHook( + eventType: HookableEventConstructor, + callback: HookCallback, + options?: HookCallbackOptions + ): HookCleanup { + return this._hooksRegistry.addCallback(eventType, callback, options) + } + + public async initialize(): Promise { + if (this._initialized) { + return + } + + // Initialize MCP clients and register their tools + await Promise.all( + this._mcpClients.map(async (client) => { + const tools = await client.listTools() + this._toolRegistry.add(tools) + client.onToolsChanged = (oldTools, newTools): void => { + oldTools.forEach((name) => this._toolRegistry.remove(name)) + this._toolRegistry.addOrReplace(newTools) + } + }) + ) + + await this._pluginRegistry.initialize(this) + + for (const handler of this._interventionRegistry.handlers) { + const observer = handler as Partial + if (typeof observer.observeAgent === 'function') { + await observer.observeAgent(this) + } + } + + await this._hooksRegistry.invokeCallbacks(new InitializedEvent({ agent: this })) + + this._initialized = true + } + + /** + * Acquires the invocation lock. Throws if an invocation is already in progress. + * Callers must release via try/finally with `this._isInvoking = false`. + */ + private acquireLock(): void { + if (this._isInvoking) { + throw new ConcurrentInvocationError( + 'Agent is already processing an invocation. Wait for the current invoke() or stream() call to complete before invoking again.' + ) + } + this._isInvoking = true + } + + /** + * Throws {@link CancelledError} if cancellation has been requested. + * Called at cancellation checkpoints within the agent loop. + */ + private _throwIfCancelled(): void { + if (this.isCancelled) { + throw new CancelledError() + } + } + + /** + * The tools this agent can use. + */ + get tools(): Tool[] { + return this._toolRegistry.list() + } + + /** + * The tool registry for managing the agent's tools. + */ + get toolRegistry(): ToolRegistry { + return this._toolRegistry + } + + /** + * Whether the agent is currently processing an invocation. + */ + get isInvoking(): boolean { + return this._isInvoking + } + + /** + * Direct tool calling accessor. + * + * Returns a proxy where each property is a {@link ToolHandle} with + * `.invoke()` and `.stream()` methods: + * ```typescript + * const result = await agent.tool.calculator!.invoke({ a: 5, b: 3 }) + * + * for await (const event of agent.tool.calculator!.stream({ a: 5, b: 3 })) { + * console.log('progress:', event) + * } + * ``` + * + * Supports underscore-to-hyphen and case-insensitive name resolution. + * Results are recorded in message history by default (pass + * `{ recordDirectToolCall: false }` to skip). + */ + get tool(): ToolCallerProxy { + return this._toolCaller + } + + /** + * The cancellation signal for the current invocation. + * + * Tools can pass this to cancellable operations (e.g., `fetch(url, { signal: agent.cancelSignal })`). + * Hooks can check `event.agent.cancelSignal.aborted` to detect cancellation. + */ + get cancelSignal(): AbortSignal { + return this._abortSignal + } + + /** + * Cancels the current agent invocation cooperatively. + * + * The agent will stop at the next cancellation checkpoint: + * - During model response streaming + * - Before tool execution + * - Between sequential tool executions + * - At the top of each agent loop cycle + * + * If a tool is already executing, it will run to completion unless + * the tool checks {@link LocalAgent.cancelSignal | cancelSignal} internally. + * + * Hook callbacks can check `event.agent.cancelSignal.aborted` to detect + * cancellation and adjust their behavior accordingly. + * + * The stream/invoke call will return an AgentResult with `stopReason: 'cancelled'`. + * If the agent is not currently invoking, this is a no-op. + * + * @example + * ```typescript + * const agent = new Agent({ model, tools }) + * + * // Cancel after 5 seconds + * setTimeout(() => agent.cancel(), 5000) + * const result = await agent.invoke('Do something') + * console.log(result.stopReason) // 'cancelled' + * ``` + */ + public cancel(): void { + if (this._isInvoking) { + this._abortController.abort() + } + } + + /** + * Whether the current invocation has been cancelled. + * Returns `false` when the agent is idle. + */ + private get isCancelled(): boolean { + return this._abortSignal.aborted + } + + /** + * Invokes the agent and returns the final result. + * + * This is a convenience method that consumes the stream() method and returns + * only the final AgentResult. Use stream() if you need access to intermediate + * streaming events. + * + * @param args - Arguments for invoking the agent + * @param options - Optional per-invocation options + * @returns Promise that resolves to the final AgentResult + * + * @example + * ```typescript + * const agent = new Agent({ model, tools }) + * const result = await agent.invoke('What is 2 + 2?') + * console.log(result.lastMessage) // Agent's response + * ``` + */ + public async invoke(args: InvokeArgs, options?: InvokeOptions): Promise { + const gen = this.stream(args, options) + let result = await gen.next() + while (!result.done) { + result = await gen.next() + } + return result.value + } + + /** + * Streams the agent execution, yielding events and returning the final result. + * + * The agent loop manages the conversation flow by: + * 1. Streaming model responses and yielding all events + * 2. Executing tools when the model requests them + * 3. Continuing the loop until the model completes without tool use + * + * Use this method when you need access to intermediate streaming events. + * For simple request/response without streaming, use invoke() instead. + * + * An explicit goal of this method is to always leave the message array in a way that + * the agent can be reinvoked with a user prompt after this method completes. To that end + * assistant messages containing tool uses are only added after tool execution succeeds + * with valid toolResponses + * + * @param args - Arguments for invoking the agent + * @param options - Optional per-invocation options + * @returns Async generator that yields AgentStreamEvent objects and returns AgentResult + * + * @example + * ```typescript + * const agent = new Agent({ model, tools }) + * + * for await (const event of agent.stream('Hello')) { + * console.log('Event:', event.type) + * } + * // Messages array is mutated in place and contains the full conversation + * ``` + */ + public async *stream( + args: InvokeArgs, + options?: InvokeOptions + ): AsyncGenerator { + this.acquireLock() + try { + await this.initialize() + + let currentArgs: InvokeArgs = args + + // Outer loop: re-enters _stream when a hook sets AfterInvocationEvent.resume. + // One invocation lock spans the whole resume chain. + while (true) { + // Fresh AbortController per invocation iteration, composed with any external signal. + this._abortController = new AbortController() + this._abortSignal = options?.cancelSignal + ? AbortSignal.any([this._abortController.signal, options.cancelSignal]) + : this._abortController.signal + + const streamGenerator = this._stream(currentArgs, options) + let caughtError: Error | undefined + let lastAfterInvocation: AfterInvocationEvent | undefined + let iterationResult: IteratorResult + try { + iterationResult = await streamGenerator.next() + + while (!iterationResult.done) { + try { + const processed = await this._invokeCallbacks(iterationResult.value) + if (processed instanceof AfterInvocationEvent) { + lastAfterInvocation = processed + } + yield processed + iterationResult = await streamGenerator.next() + } catch (error) { + // Throw interrupt errors back into _stream so executeTools can store the + // assistant message as pending execution state for resume. + if (error instanceof InterruptError) { + iterationResult = await streamGenerator.throw(error) + } else { + throw error + } + } + } + + // Suppress AgentResultEvent for resumed iterations — only the final + // invocation in a resume chain reports an agent result. + if (lastAfterInvocation?.resume === undefined) { + yield await this._invokeCallbacks( + new AgentResultEvent({ + agent: this, + result: iterationResult.value, + invocationState: iterationResult.value.invocationState, + }) + ) + } + } catch (error) { + caughtError = error as Error + throw error + } finally { + // Drain _stream() so cleanup hooks and printer still fire. + // Yield only on error (consumer may still be iterating); on a consumer + // break, yielding would suspend the generator and leak the lock. + let drainResult = await streamGenerator.return(undefined as never) + while (!drainResult.done) { + try { + if (caughtError) { + yield await this._invokeCallbacks(drainResult.value) + } else { + await this._invokeCallbacks(drainResult.value) + } + } catch (error) { + logger.warn( + `event_type=<${drainResult.value.type}>, error=<${error}> | error invoking callbacks during cleanup` + ) + } + drainResult = await streamGenerator.next() + } + + // Reset controller and signal for next iteration / invocation + this._abortController = new AbortController() + this._abortSignal = this._abortController.signal + } + + // Resume only on a clean invocation — errors propagate above. + if (lastAfterInvocation?.resume !== undefined) { + currentArgs = lastAfterInvocation.resume + continue + } + + return iterationResult.value + } + } finally { + this._isInvoking = false + } + } + + /** + * Returns a {@link Tool} that wraps this agent, allowing it to be used + * as a tool by another agent. + * + * The returned tool accepts a single `input` string parameter, invokes + * this agent, and returns the text response as a tool result. + * + * **Note:** You can also pass an Agent directly in another agent's + * {@link AgentConfig.tools | tools} array — it will be wrapped + * automatically via this method. + * + * @param options - Optional configuration for the tool name, description, and context preservation + * @returns A Tool wrapping this agent + * + * @example + * ```typescript + * const researcher = new Agent({ name: 'researcher', description: 'Finds info', printer: false }) + * + * // Explicit wrapping + * const writer = new Agent({ tools: [researcher.asTool()] }) + * + * // Automatic wrapping (equivalent) + * const writer = new Agent({ tools: [researcher] }) + * ``` + */ + public asTool(options?: AgentAsToolOptions): Tool { + return new AgentAsTool({ agent: this, ...options }) + } + + /** + * Captures a point-in-time snapshot of the agent's current state. + * + * Use snapshots to checkpoint agent state for later restoration, enabling + * use cases like undo/redo, branching conversations, and session persistence. + * + * Fields are selected via a preset/include/exclude model: + * 1. Start with preset fields (e.g. `'session'` captures all fields) + * 2. Add any `include` fields + * 3. Remove any `exclude` fields + * + * @param options - Controls which fields to capture and optional app data to store + * @returns A {@link Snapshot} containing the captured agent state + * @throws Error if no fields would be included after applying options + * + * @example + * ```typescript + * // Capture all session-relevant state + * const snapshot = agent.takeSnapshot({ preset: 'session' }) + * + * // Capture only messages and state + * const partial = agent.takeSnapshot({ include: ['messages', 'state'] }) + * + * // Capture session state but exclude interrupts + * const noInterrupts = agent.takeSnapshot({ preset: 'session', exclude: ['interrupts'] }) + * + * // Attach application-owned metadata + * const withMeta = agent.takeSnapshot({ preset: 'session', appData: { userId: 'u-123' } }) + * ``` + */ + public takeSnapshot(options: TakeSnapshotOptions): Snapshot { + return takeSnapshotInternal(this, options) + } + + /** + * Restores agent state from a previously captured snapshot. + * + * Only fields present in `snapshot.data` are restored; absent fields are left + * unchanged. This allows partial snapshots to update specific aspects of state + * without affecting others. + * + * @param snapshot - The snapshot to restore from + * @throws Error if `snapshot.schemaVersion` is incompatible or scope is wrong + * + * @example + * ```typescript + * // Save and restore a conversation checkpoint + * const checkpoint = agent.takeSnapshot({ preset: 'session' }) + * + * // ... agent continues processing ... + * + * // Restore to the checkpoint + * agent.loadSnapshot(checkpoint) + * + * // Restore from a JSON-serialized snapshot (e.g. from storage) + * const stored = JSON.parse(savedSnapshotJson) + * agent.loadSnapshot(stored) + * ``` + */ + public loadSnapshot(snapshot: Snapshot): void { + loadSnapshotInternal(this, snapshot) + } + + /** + * Invokes hook callbacks and printer for a stream event. + * + * @param event - The event to process + * @returns The event after processing + */ + private async _invokeCallbacks(event: AgentStreamEvent): Promise { + if (event instanceof HookableEvent) { + await this._hooksRegistry.invokeCallbacks(event) + } + this._printer?.processEvent(event) + return event + } + + /** + * Internal implementation of the agent streaming logic. + * Separated to centralize printer event processing in the public stream method. + * + * @param args - Arguments for invoking the agent + * @param options - Optional per-invocation options + * @returns Async generator that yields AgentStreamEvent objects and returns AgentResult + */ + private async *_stream( + args: InvokeArgs, + options?: InvokeOptions + ): AsyncGenerator { + let currentArgs: InvokeArgs | undefined = args + let result: AgentResult | undefined + + // Resolve structured output schema from per-invocation options or constructor config + const structuredOutputSchema = options?.structuredOutputSchema ?? this._structuredOutputSchema + const structuredOutputTool = structuredOutputSchema ? new StructuredOutputTool(structuredOutputSchema) : undefined + let structuredOutputChoice: ToolChoice | undefined + + // Resolve per-invocation state once. The same object is threaded through + // every lifecycle hook event, every tool context, and is surfaced on the + // AgentResult. Mutations by hooks/tools are visible across all recursive + // agent loop cycles within this invocation. + const invocationState: InvocationState = options?.invocationState ?? {} + + // Handle interrupt responses if present in input + const interruptResponses = this._extractInterruptResponses(args) + if (interruptResponses.length > 0) { + this._interruptState.resume(interruptResponses) + } + + // Reject non-interrupt input while in interrupted state + if (this._interruptState.activated && interruptResponses.length === 0) { + throw new TypeError('Agent is in an interrupted state. Resume by invoking with interruptResponse content blocks.') + } + + const beforeInvocationEvent = new BeforeInvocationEvent({ agent: this, invocationState }) + yield beforeInvocationEvent + + if (beforeInvocationEvent.cancel) { + const cancelText = + typeof beforeInvocationEvent.cancel === 'string' ? beforeInvocationEvent.cancel : 'invocation denied by hook' + const message = new Message({ role: 'assistant', content: [new TextBlock(cancelText)] }) + yield this._appendMessage(message, invocationState) + yield new AfterInvocationEvent({ agent: this, invocationState }) + return new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + traces: this._tracer.localTraces, + metrics: this._meter.metrics, + invocationState, + }) + } + + // Normalize input to get the user messages for telemetry + const inputMessages = this._normalizeInput(args) + + // Start agent trace span + this._meter.startNewInvocation() + const agentModelId = this.model.modelId + const agentSpanOptions: Parameters[0] = { + messages: inputMessages, + agentName: this.name, + agentId: this.id, + tools: this.tools, + } + if (agentModelId) agentSpanOptions.modelId = agentModelId + if (this.systemPrompt !== undefined) agentSpanOptions.systemPrompt = this.systemPrompt + const agentSpan = this._tracer.startAgentSpan(agentSpanOptions) + + let caughtError: Error | undefined + try { + // Register structured output tool if schema provided + if (structuredOutputTool) { + this._toolRegistry.add(structuredOutputTool) + } + + // Main agent loop - continues until model stops without requesting tools + while (true) { + this._throwIfCancelled() + + // Start metrics cycle tracking + const { cycleId, startTime: cycleStartTime } = this._meter.startCycle() + + // Create agent loop cycle span within agent span context + const cycleSpan = this._tracer.startAgentLoopSpan({ + cycleId, + messages: this.messages, + }) + + try { + // Normalize input and append user messages on first invocation only + if (currentArgs !== undefined) { + const messagesToAppend = this._normalizeInput(currentArgs) + for (const message of messagesToAppend) { + yield this._appendMessage(message, invocationState) + } + currentArgs = undefined + } + + // Check if we're resuming from a tool interrupt + const pendingExecution = this._interruptState.getPendingExecution() + let assistantMessage: Message + let completedToolResults: Map | undefined + + if (pendingExecution) { + // Resume from stored state - skip model call + assistantMessage = pendingExecution.assistantMessage + completedToolResults = pendingExecution.completedToolResults + this._interruptState.clearPendingToolExecution() + } else { + const modelResult = yield* this._invokeModel(invocationState, structuredOutputChoice) + + if (modelResult.stopReason !== 'toolUse') { + // Schema set, we already forced, and the model still refused. + // Throw before closing the span so the cycle span records the error. + if (structuredOutputTool && structuredOutputChoice) { + throw new StructuredOutputError( + 'The model failed to invoke the structured output tool even after it was forced.' + ) + } + + this._meter.endCycle(cycleStartTime) + this._tracer.endAgentLoopSpan(cycleSpan) + + // Schema set, model ignored the tool — drop the response and force the tool next cycle. + // Appending the plain-text turn here would leave the conversation ending on an + // assistant message, which providers like Bedrock reject as assistant prefill. + if (structuredOutputTool) { + structuredOutputChoice = { tool: { name: STRUCTURED_OUTPUT_TOOL_NAME } } + logger.debug( + 'structured output schema set but model responded with plain text; forcing tool use on next cycle' + ) + continue + } + + // Normal end of turn. + yield this._appendMessage(modelResult.message, invocationState) + result = new AgentResult({ + stopReason: modelResult.stopReason, + lastMessage: modelResult.message, + traces: this._tracer.localTraces, + metrics: this._meter.metrics, + invocationState, + }) + return result + } + + // Cancel before tool execution: create error results for all pending tools + if (this.isCancelled) { + const toolUseBlocks = modelResult.message.content.filter( + (block): block is ToolUseBlock => block.type === 'toolUseBlock' + ) + const cancelBlocks = toolUseBlocks.map( + (block) => + new ToolResultBlock({ + toolUseId: block.toolUseId, + status: 'error', + content: [new TextBlock('Tool execution cancelled')], + }) + ) + const toolResultMessage = new Message({ role: 'user', content: cancelBlocks }) + + yield this._appendMessage(modelResult.message, invocationState) + yield this._appendMessage(toolResultMessage, invocationState) + + this._meter.endCycle(cycleStartTime) + this._tracer.endAgentLoopSpan(cycleSpan) + + result = new AgentResult({ + stopReason: 'cancelled', + lastMessage: modelResult.message, + traces: this._tracer.localTraces, + metrics: this._meter.metrics, + invocationState, + }) + return result + } + + assistantMessage = modelResult.message + } + + // Execute tools + const toolsResult = yield* this.executeTools( + assistantMessage, + this._toolRegistry, + invocationState, + completedToolResults + ) + + // When the consumer breaks the stream (e.g. agent.cancel() + break), + // yield* returns undefined because the inner generator was closed. + if (!toolsResult) { + this._meter.endCycle(cycleStartTime) + this._tracer.endAgentLoopSpan(cycleSpan) + continue + } + const toolResultMessage = toolsResult.message + + /** + * Deferred append: both messages are added AFTER tool execution completes. + * This keeps agent.messages in a valid, reinvokable state at all times. + * If interrupted during tool execution, messages has no dangling toolUse + * without a matching toolResult, so the agent can be reinvoked cleanly. + */ + yield this._appendMessage(assistantMessage, invocationState) + yield this._appendMessage(toolResultMessage, invocationState) + + // Deactivate interrupt state after successful tool execution so the next + // cycle starts with a clean slate (new interrupts can be raised again). + if (this._interruptState.activated) { + this._interruptState.deactivate() + } + + this._meter.endCycle(cycleStartTime) + this._tracer.endAgentLoopSpan(cycleSpan) + + // Hook requested halt: exit without calling the model again + const { afterToolsEvent } = toolsResult + if (afterToolsEvent.endTurn) { + const endTurnText = + typeof afterToolsEvent.endTurn === 'string' + ? afterToolsEvent.endTurn + : 'Turn ended early by hook after tool execution' + const lastMessage = new Message({ role: 'assistant', content: [new TextBlock(endTurnText)] }) + yield this._appendMessage(lastMessage, invocationState) + + result = new AgentResult({ + stopReason: 'endTurn', + lastMessage, + traces: this._tracer.localTraces, + metrics: this._meter.metrics, + invocationState, + }) + return result + } + + // Structured output captured: exit + const structuredOutput = structuredOutputTool + ? this._extractStructuredOutput(assistantMessage, toolResultMessage) + : undefined + if (structuredOutput !== undefined) { + result = new AgentResult({ + stopReason: 'toolUse', + lastMessage: assistantMessage, + traces: this._tracer.localTraces, + structuredOutput, + metrics: this._meter.metrics, + invocationState, + }) + return result + } + } catch (error) { + this._meter.endCycle(cycleStartTime) + this._tracer.endAgentLoopSpan(cycleSpan, { error: error as Error }) + throw error + } + } + } catch (error) { + if (error instanceof CancelledError) { + // Cancelled during model streaming or at the top of a cycle. + // No partial messages have been appended (deferred append pattern). + const cancelMessage = new Message({ + role: 'assistant', + content: [new TextBlock('Cancelled by user')], + }) + yield this._appendMessage(cancelMessage, invocationState) + + result = new AgentResult({ + stopReason: 'cancelled', + lastMessage: cancelMessage, + traces: this._tracer.localTraces, + metrics: this._meter.metrics, + invocationState, + }) + return result + } + if (error instanceof InterruptError) { + // Fan out one event per interrupt. Each event exposes `interrupt.source` so + // consumers can filter by origin (tool callback vs hook callback) without + // subscribing to separate event types. + for (const interrupt of error.interrupts) { + yield new InterruptEvent({ agent: this, interrupt, invocationState }) + } + result = this._createInterruptResult(invocationState) + return result + } + caughtError = error as Error + throw error + } finally { + // If cancelled but the catch block was bypassed (generator terminated + // via .return() when the consumer breaks out of for-await), append an + // assistant message so the agent can be reinvoked with a new user prompt. + if (!caughtError && !result && this.isCancelled) { + const cancelMessage = new Message({ + role: 'assistant', + content: [new TextBlock('Cancelled by user')], + }) + yield this._appendMessage(cancelMessage, invocationState) + } + + this._tracer.endAgentSpan(agentSpan, { + ...(caughtError && { error: caughtError }), + ...(result?.lastMessage && { response: result.lastMessage }), + accumulatedUsage: this._meter.metrics.accumulatedUsage, + ...(result?.stopReason && { stopReason: result.stopReason }), + }) + + // Cleanup structured output tool + if (structuredOutputTool) { + this._toolRegistry.remove(STRUCTURED_OUTPUT_TOOL_NAME) + } + + // Always emit final event + yield new AfterInvocationEvent({ agent: this, invocationState }) + } + } + + /** + * Extracts the validated structured output result from tool execution. + * + * @param toolUseMessage - The assistant message containing tool use blocks + * @param toolResultMessage - The message containing tool results + * @returns The parsed structured output, or undefined if not found + */ + private _extractStructuredOutput(toolUseMessage: Message, toolResultMessage: Message): unknown | undefined { + const toolUse = toolUseMessage.content.find( + (block): block is ToolUseBlock => block.type === 'toolUseBlock' && block.name === STRUCTURED_OUTPUT_TOOL_NAME + ) + if (!toolUse) return undefined + + const toolResult = toolResultMessage.content.find( + (block): block is ToolResultBlock => + block.type === 'toolResultBlock' && block.toolUseId === toolUse.toolUseId && block.status === 'success' + ) + if (!toolResult) return undefined + + const firstContent = toolResult.content[0] + return firstContent?.type === 'jsonBlock' ? firstContent.json : undefined + } + + /** + * Creates an AgentResult for an interrupt stop. + * + * @param invocationState - The current invocation state + * @returns AgentResult with stopReason 'interrupt' + */ + private _createInterruptResult(invocationState: InvocationState): AgentResult { + this._interruptState.activate() + return new AgentResult({ + stopReason: 'interrupt', + lastMessage: + this.messages.length > 0 + ? this.messages[this.messages.length - 1]! + : new Message({ role: 'assistant', content: [new TextBlock('Interrupted')] }), + traces: this._tracer.localTraces, + metrics: this._meter.metrics, + interrupts: this._interruptState.getUnansweredInterrupts(), + invocationState, + }) + } + + /** + * Extracts interrupt response content blocks from invocation args. + * + * @param args - The invocation arguments + * @returns Array of InterruptResponseContent blocks, empty if none found + * @throws TypeError if args mix interrupt responses with other content + */ + private _extractInterruptResponses(args: InvokeArgs): InterruptResponseContent[] { + if (!Array.isArray(args) || args.length === 0) { + return [] + } + + const responses: InterruptResponseContent[] = [] + let hasNonInterrupt = false + + for (const item of args) { + if (isInterruptResponseContent(item)) { + responses.push(item) + } else { + hasNonInterrupt = true + } + } + + if (responses.length > 0 && hasNonInterrupt) { + throw new TypeError('Must resume from interrupt with a list of interruptResponse content blocks only') + } + + return responses + } + + /** + * Normalizes agent invocation input into an array of messages to append. + * + * @param args - Optional arguments for invoking the model + * @returns Array of messages to append to the conversation + */ + private _normalizeInput(args?: InvokeArgs): Message[] { + if (args !== undefined) { + if (typeof args === 'string') { + // String input: wrap in TextBlock and create user Message + return [ + new Message({ + role: 'user', + content: [new TextBlock(args)], + }), + ] + } else if (Array.isArray(args) && args.length > 0) { + const firstElement = args[0]! + + // Check if it's interrupt responses - skip creating messages for these + if (isInterruptResponseContent(firstElement)) { + // Pure interrupt responses: no messages to add + return [] + } + + // Check if it's Message[] or MessageData[] + if ('role' in firstElement && typeof firstElement.role === 'string') { + // Check if it's a Message instance or MessageData + if (firstElement instanceof Message) { + // Message[] input: return all messages + return args as Message[] + } else { + // MessageData[] input: convert to Message[] + return (args as MessageData[]).map((data) => Message.fromMessageData(data)) + } + } else { + // It's ContentBlock[] or ContentBlockData[] + // Check if it's ContentBlock instances or ContentBlockData + let contentBlocks: ContentBlock[] + if ('type' in firstElement && typeof firstElement.type === 'string') { + // ContentBlock[] input: use as-is + contentBlocks = args as ContentBlock[] + } else { + // ContentBlockData[] input: convert using helper function + contentBlocks = (args as ContentBlockData[]).map(contentBlockFromData) + } + + return [ + new Message({ + role: 'user', + content: contentBlocks, + }), + ] + } + } + } + // undefined or empty array: no messages to append + return [] + } + + /** + * Invokes the model provider and streams all events. + * + * @param args - Optional arguments for invoking the model + * @param toolChoice - Optional tool choice to force specific tool usage + * @returns Object containing the assistant message, stop reason, and optional redaction message + */ + private async *_invokeModel( + invocationState: InvocationState, + toolChoice?: ToolChoice + ): AsyncGenerator { + const toolSpecs = this._toolRegistry.list().map((tool) => tool.toolSpec) + const streamOptions: StreamOptions = { toolSpecs, modelState: this.modelState } + if (this.systemPrompt !== undefined) { + streamOptions.systemPrompt = this.systemPrompt + } + + // Add tool choice if provided + if (toolChoice) { + streamOptions.toolChoice = toolChoice + } + + let attemptCount = 1 + while (true) { + // Estimate input tokens for the upcoming model call (non-fatal if estimation fails) + let projectedInputTokens: number | undefined + try { + projectedInputTokens = await this._estimateInputTokens(streamOptions) + } catch (e) { + logger.debug(`error=<${e}> | token estimation failed, proceeding without estimate`) + } + + const beforeModelCallEvent = new BeforeModelCallEvent({ + agent: this, + model: this.model, + invocationState, + ...(projectedInputTokens !== undefined && { projectedInputTokens }), + }) + yield beforeModelCallEvent + + if (beforeModelCallEvent.cancel) { + const cancelText = + typeof beforeModelCallEvent.cancel === 'string' ? beforeModelCallEvent.cancel : 'model call denied by hook' + const message = new Message({ role: 'assistant', content: [new TextBlock(cancelText)] }) + const stopData: ModelStopData = { message, stopReason: 'endTurn' } + const afterModelCallEvent = new AfterModelCallEvent({ + agent: this, + model: this.model, + attemptCount, + stopData, + invocationState, + }) + yield afterModelCallEvent + + if (afterModelCallEvent.retry) { + attemptCount += 1 + continue + } + + return { message, stopReason: 'endTurn' } + } + + // Start model span within loop span context + const modelId = this.model.modelId + const modelSpan = this._tracer.startModelInvokeSpan({ + messages: this.messages, + ...(modelId && { modelId }), + ...(this.systemPrompt !== undefined && { systemPrompt: this.systemPrompt }), + }) + + try { + const result = yield* this._streamFromModel(this.messages, streamOptions, invocationState) + + // Accumulate token usage and model latency metrics + this._meter.updateCycle(result.metadata) + + // End model span with usage + const usage = result.metadata?.usage + const metrics = result.metadata?.metrics + this._tracer.endModelInvokeSpan(modelSpan, { + output: result.message, + stopReason: result.stopReason, + ...(usage && { usage }), + ...(metrics && { metrics }), + }) + + yield new ModelMessageEvent({ + agent: this, + message: result.message, + stopReason: result.stopReason, + invocationState, + }) + + // Handle user content redaction if guardrails blocked input + if (result.redaction?.userMessage) { + this._redactLastMessage(result.redaction.userMessage) + } + + const stopData: ModelStopData = { + message: result.message, + stopReason: result.stopReason, + ...(result.redaction && { redaction: result.redaction }), + } + + const afterModelCallEvent = new AfterModelCallEvent({ + agent: this, + model: this.model, + attemptCount, + stopData, + invocationState, + }) + yield afterModelCallEvent + + if (afterModelCallEvent.retry) { + attemptCount += 1 + continue + } + + return result + } catch (error) { + const modelError = normalizeError(error) + + // End model span with error + this._tracer.endModelInvokeSpan(modelSpan, { error: modelError }) + + // Create error event + const errorEvent = new AfterModelCallEvent({ + agent: this, + model: this.model, + attemptCount, + error: modelError, + invocationState, + }) + + // Yield error event - stream will invoke hooks + yield errorEvent + + // Let CancelledError propagate directly — no retry + // (we emit the AfterModelCall because we already emitted Before and we guarentee the pair) + if (error instanceof CancelledError) { + throw error + } + + // After yielding, hooks have been invoked and may have set retry + if (errorEvent.retry) { + attemptCount += 1 + continue + } + + // Re-throw error + throw error + } + } + } + + /** + * Streams events from the model and dispatches appropriate events for each. + * + * The model's `streamAggregated()` yields two kinds of output: + * - **ModelStreamEvent**: Transient streaming deltas (partial data while generating). + * Wrapped in {@link ModelStreamUpdateEvent} before yielding. + * - **ContentBlock**: Fully assembled results (after all deltas accumulate). + * Wrapped in {@link ContentBlockEvent} before yielding. + * + * These are separate event classes because they represent different granularities + * (partial deltas vs finished blocks). Both are yielded in the stream and hookable. + * + * @param messages - Messages to send to the model + * @param streamOptions - Options for streaming + * @returns StreamAggregatedResult containing message, stop reason, and optional redaction message + */ + private async *_streamFromModel( + messages: Message[], + streamOptions: StreamOptions, + invocationState: InvocationState + ): AsyncGenerator { + messages = normalizeToolUseNames(messages) + const streamGenerator = this.model.streamAggregated(messages, streamOptions) + let result = await streamGenerator.next() + + while (!result.done) { + this._throwIfCancelled() + + const event = result.value + + if (isModelStreamEvent(event)) { + // ModelStreamEvent: wrap in ModelStreamUpdateEvent + yield new ModelStreamUpdateEvent({ agent: this, event, invocationState }) + } else { + // ContentBlock: wrap in ContentBlockEvent + yield new ContentBlockEvent({ agent: this, contentBlock: event, invocationState }) + } + result = await streamGenerator.next() + } + + // result.done is true, result.value contains the return value + return result.value + } + + /** + * Emits `BeforeToolsEvent`, handles the pre-launch cancel paths, then + * delegates per-tool execution to the configured {@link ToolExecutorStrategy}. + * Always pairs `BeforeToolsEvent` with a terminal `AfterToolsEvent`, even on + * the invariant-violation throw path. + * + * @param assistantMessage - The assistant message containing tool use blocks + * @param toolRegistry - Registry containing available tools + * @returns Tool-result message and the dispatched AfterToolsEvent + */ + private async *executeTools( + assistantMessage: Message, + toolRegistry: ToolRegistry, + invocationState: InvocationState, + completedToolResults?: Map + ): AsyncGenerator { + const beforeToolsEvent = new BeforeToolsEvent({ agent: this, message: assistantMessage, invocationState }) + try { + yield beforeToolsEvent + } catch (error) { + // Store pending state before re-throwing so the agent can resume from this point. + // The error must still propagate to _stream which handles the interrupt stop. + if (error instanceof InterruptError) { + this._interruptState.setPendingToolExecution({ + assistantMessageData: assistantMessage.toJSON(), + completedToolResults: {}, + }) + } + throw error + } + + const toolUseBlocks = assistantMessage.content.filter( + (block): block is ToolUseBlock => block.type === 'toolUseBlock' + ) + if (toolUseBlocks.length === 0) { + // Preserve BeforeToolsEvent/AfterToolsEvent bracket symmetry even on + // this invariant-violation branch. + yield new AfterToolsEvent({ + agent: this, + message: new Message({ role: 'user', content: [] }), + invocationState, + }) + throw new Error('Model indicated toolUse but no tool use blocks found in message') + } + + // Pre-launch cancel paths are strategy-independent. + if (beforeToolsEvent.cancel) { + const message = typeof beforeToolsEvent.cancel === 'string' ? beforeToolsEvent.cancel : 'Tool cancelled by hook' + return yield* this._yieldCancelledToolResults(toolUseBlocks, message, invocationState) + } + if (this.isCancelled) { + return yield* this._yieldCancelledToolResults(toolUseBlocks, 'Tool execution cancelled', invocationState) + } + + switch (this._toolExecutor) { + case 'sequential': + return yield* this._executeToolsSequential( + toolUseBlocks, + toolRegistry, + invocationState, + completedToolResults, + assistantMessage + ) + case 'concurrent': + return yield* this._executeToolsConcurrent( + toolUseBlocks, + toolRegistry, + invocationState, + completedToolResults, + assistantMessage + ) + default: { + const _exhaustive: never = this._toolExecutor + throw new Error(`Unknown toolExecutor: ${_exhaustive as string}`) + } + } + } + + /** + * Emits a `ToolResultEvent` for every block plus an `AfterToolsEvent`, and + * returns the resulting tool-result message and dispatched event. Used by the pre-launch cancel + * paths shared across executors. + */ + private async *_yieldCancelledToolResults( + toolUseBlocks: ToolUseBlock[], + message: string, + invocationState: InvocationState + ): AsyncGenerator { + const cancelBlocks = this._cancelAllAsResults(toolUseBlocks, message) + for (const result of cancelBlocks) { + yield new ToolResultEvent({ agent: this, result, invocationState }) + } + const toolResultMessage = new Message({ role: 'user', content: cancelBlocks }) + const afterToolsEvent = new AfterToolsEvent({ agent: this, message: toolResultMessage, invocationState }) + yield afterToolsEvent + return { message: toolResultMessage, afterToolsEvent } + } + + /** + * Executes tools one at a time, honoring `agent.cancelSignal` between + * iterations to short-circuit not-yet-started tools. + */ + private async *_executeToolsSequential( + toolUseBlocks: ToolUseBlock[], + toolRegistry: ToolRegistry, + invocationState: InvocationState, + completedToolResults?: Map, + assistantMessage?: Message + ): AsyncGenerator { + const toolResultBlocks: ToolResultBlock[] = [] + let toolResultMessage: Message + let afterToolsEvent: AfterToolsEvent + + try { + for (const toolUseBlock of toolUseBlocks) { + // Skip tools that were already completed before the interrupt + if (completedToolResults?.has(toolUseBlock.toolUseId)) { + const completedResult = completedToolResults.get(toolUseBlock.toolUseId)! + // No events emitted for already-completed tools. + // The result is included in the final tool result message. + toolResultBlocks.push(completedResult) + continue + } + + if (this.isCancelled) { + const cancelBlock = new ToolResultBlock({ + toolUseId: toolUseBlock.toolUseId, + status: 'error', + content: [new TextBlock('Tool execution cancelled')], + }) + toolResultBlocks.push(cancelBlock) + yield new ToolResultEvent({ agent: this, result: cancelBlock, invocationState }) + continue + } + + try { + const toolResultBlock = yield* this.executeTool(toolUseBlock, toolRegistry, invocationState) + toolResultBlocks.push(toolResultBlock) + yield new ToolResultEvent({ agent: this, result: toolResultBlock, invocationState }) + } catch (error) { + if (error instanceof InterruptError) { + // Store pending state with completed results so far + const completedSoFar: Record = {} + for (const block of toolResultBlocks) { + completedSoFar[block.toolUseId] = block.toJSON() + } + // Also include any previously completed results + if (completedToolResults) { + for (const [id, block] of completedToolResults) { + completedSoFar[id] = block.toJSON() + } + } + this._interruptState.setPendingToolExecution({ + assistantMessageData: assistantMessage!.toJSON(), + completedToolResults: completedSoFar, + }) + throw error + } + throw error + } + } + } finally { + toolResultMessage = new Message({ role: 'user', content: toolResultBlocks }) + afterToolsEvent = new AfterToolsEvent({ agent: this, message: toolResultMessage, invocationState }) + yield afterToolsEvent + } + + return { message: toolResultMessage, afterToolsEvent } + } + + /** + * Produces one error ToolResultBlock per tool use block, each carrying + * `message` as its error text. Shared by pre-launch cancel paths. + */ + private _cancelAllAsResults(toolUseBlocks: ToolUseBlock[], message: string): ToolResultBlock[] { + return toolUseBlocks.map( + (block) => + new ToolResultBlock({ + toolUseId: block.toolUseId, + status: 'error', + content: [new TextBlock(message)], + }) + ) + } + + /** + * Executes tools concurrently by merging N per-tool {@link executeTool} + * async generators via `Promise.race`. Per-tool event order is preserved + * (because each generator is iterated serially); cross-tool events may + * interleave at race resolution boundaries. + * + * Per-tool retry (`AfterToolCallEvent.retry`) is isolated — it lives inside + * `executeTool`'s own `while(true)` loop, so one tool retrying does not + * disturb its siblings. + */ + private async *_executeToolsConcurrent( + toolUseBlocks: ToolUseBlock[], + toolRegistry: ToolRegistry, + invocationState: InvocationState, + completedToolResults?: Map, + assistantMessage?: Message + ): AsyncGenerator { + let toolResultMessage: Message + let afterToolsEvent: AfterToolsEvent + + // Wrap each in-flight `.next()` so the raced promise always resolves to a + // tagged Step. That prevents one generator rejection from rejecting the + // whole race and lets us convert per-tool failures into ToolResultBlocks + // without orphaning other generators. + type Step = + | { idx: number; kind: 'next'; res: IteratorResult } + | { idx: number; kind: 'throw'; error: unknown } + + const gens = toolUseBlocks.map((block) => ({ + block, + gen: completedToolResults?.has(block.toolUseId) + ? undefined // Skip already-completed tools + : this.executeTool(block, toolRegistry, invocationState), + })) + + const step = (idx: number): Promise => + gens[idx]!.gen!.next().then( + (res): Step => ({ idx, kind: 'next', res }), + (error: unknown): Step => ({ idx, kind: 'throw', error }) + ) + + // Seed completed results from resume state + const resultsByToolUseId = new Map() + if (completedToolResults) { + for (const [id, result] of completedToolResults) { + resultsByToolUseId.set(id, result) + } + } + + // Only race tools that need execution + const pendingNext = new Map>() + for (let idx = 0; idx < gens.length; idx++) { + if (gens[idx]!.gen) { + pendingNext.set(idx, step(idx)) + } + } + + // Track interrupts — let all other tools finish before propagating + let interruptError: InterruptError | undefined + + try { + while (pendingNext.size > 0) { + const winner = await Promise.race(pendingNext.values()) + const { idx } = winner + const block = gens[idx]!.block + + if (winner.kind === 'throw') { + pendingNext.delete(idx) + + // Detect InterruptError — don't convert to error result, track it + if (winner.error instanceof InterruptError) { + interruptError = winner.error + continue + } + + const err = normalizeError(winner.error) + const result = new ToolResultBlock({ + toolUseId: block.toolUseId, + status: 'error', + content: [new TextBlock(err.message)], + error: err, + }) + resultsByToolUseId.set(block.toolUseId, result) + yield new ToolResultEvent({ agent: this, result, invocationState }) + continue + } + + if (winner.res.done) { + pendingNext.delete(idx) + resultsByToolUseId.set(block.toolUseId, winner.res.value) + yield new ToolResultEvent({ agent: this, result: winner.res.value, invocationState }) + } else { + try { + yield winner.res.value + } catch (e) { + // InterruptError thrown back into generator from stream() error injection + if (e instanceof InterruptError) { + interruptError = e + pendingNext.delete(idx) + continue + } + throw e + } + pendingNext.set(idx, step(idx)) + } + } + + // After all tools finish, propagate interrupt if one was raised + if (interruptError) { + const completedSoFar: Record = {} + for (const [id, result] of resultsByToolUseId) { + completedSoFar[id] = result.toJSON() + } + this._interruptState.setPendingToolExecution({ + assistantMessageData: assistantMessage!.toJSON(), + completedToolResults: completedSoFar, + }) + throw interruptError + } + } finally { + // Close any generators still in-flight (e.g. consumer broke out of stream). + await Promise.allSettled( + Array.from(pendingNext.keys(), (idx) => gens[idx]!.gen!.return(undefined as unknown as ToolResultBlock)) + ) + + // Build the result message from whatever completed, in source order. + // Missing entries get a fallback error block so the message always + // accounts for every toolUseBlock the model emitted. + const toolResultBlocks: ToolResultBlock[] = [] + for (const block of toolUseBlocks) { + const result = resultsByToolUseId.get(block.toolUseId) + if (result) { + toolResultBlocks.push(result) + } else { + toolResultBlocks.push( + new ToolResultBlock({ + toolUseId: block.toolUseId, + status: 'error', + content: [new TextBlock('Tool execution interrupted')], + }) + ) + } + } + + toolResultMessage = new Message({ role: 'user', content: toolResultBlocks }) + afterToolsEvent = new AfterToolsEvent({ agent: this, message: toolResultMessage, invocationState }) + yield afterToolsEvent + } + + return { message: toolResultMessage, afterToolsEvent } + } + + /** + * Executes a single tool and returns the result. + * If the tool is not found or fails to return a result, returns an error ToolResult + * instead of throwing an exception. This allows the agent loop to continue and + * let the model handle the error gracefully. + * + * @param toolUseBlock - Tool use block to execute + * @param toolRegistry - Registry containing available tools + * @returns Tool result block + */ + private async *executeTool( + toolUseBlock: ToolUseBlock, + toolRegistry: ToolRegistry, + invocationState: InvocationState + ): AsyncGenerator { + const registryTool = toolRegistry.get(toolUseBlock.name) + + // Create toolUse object for hook events and telemetry. Callbacks may mutate + // this object's fields (input/name/toolUseId) inside BeforeToolCallEvent. + const toolUse = { + name: toolUseBlock.name, + toolUseId: toolUseBlock.toolUseId, + input: toolUseBlock.input, + } + + // Retry loop for tool execution + while (true) { + const beforeToolCallEvent = new BeforeToolCallEvent({ + agent: this, + toolUse, + tool: registryTool, + invocationState, + }) + yield beforeToolCallEvent + + // Resolve the tool that would actually execute. selectedTool wins; + // otherwise if the hook renamed toolUse.name, re-resolve from the + // registry under the new name; otherwise use the original registry + // lookup. Resolved before the cancel check so AfterToolCallEvent.tool + // is consistent whether the cancel or execution branch runs. + const effectiveTool = + beforeToolCallEvent.selectedTool ?? + (toolUse.name !== toolUseBlock.name ? toolRegistry.get(toolUse.name) : registryTool) + + // Cancel individual tool if hook requested it + if (beforeToolCallEvent.cancel) { + const cancelMessage = + typeof beforeToolCallEvent.cancel === 'string' ? beforeToolCallEvent.cancel : 'Tool cancelled by hook' + const cancelResult = new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status: 'error', + content: [new TextBlock(cancelMessage)], + }) + const afterToolCallEvent = new AfterToolCallEvent({ + agent: this, + toolUse, + tool: effectiveTool, + result: cancelResult, + invocationState, + }) + yield afterToolCallEvent + if (afterToolCallEvent.retry) { + continue + } + return afterToolCallEvent.result + } + + // Start tool span within loop span context + const toolSpan = this._tracer.startToolCallSpan({ + tool: toolUse, + }) + + // Track tool execution time for metrics + const toolStartTime = Date.now() + + let toolResult: ToolResultBlock + let error: Error | undefined + + if (!effectiveTool) { + // Tool not found + toolResult = new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status: 'error', + content: [new TextBlock(`Tool '${toolUse.name}' not found in registry`)], + }) + } else { + // Execute tool within the tool span context + const toolContext: ToolContext = { + toolUse: { + name: toolUse.name, + toolUseId: toolUse.toolUseId, + input: toolUse.input, + }, + agent: this, + invocationState, + interrupt: (params: InterruptParams): T => { + return interruptFromAgent(this, `tool:${toolUseBlock.toolUseId}:${params.name}`, params, 'tool') + }, + } + + try { + // Manually iterate tool stream to wrap each ToolStreamEvent in ToolStreamUpdateEvent. + // This keeps the tool authoring interface unchanged — tools construct ToolStreamEvent + // without knowledge of agents or hooks, and we wrap at the boundary. + // Tool execution is ran within the tool span's context so that + // downstream calls (e.g., MCP clients) can propagate trace context + const toolGenerator = this._tracer.withSpanContext(toolSpan, () => effectiveTool.stream(toolContext)) + let toolNext = await this._tracer.withSpanContext(toolSpan, () => toolGenerator.next()) + while (!toolNext.done) { + yield new ToolStreamUpdateEvent({ agent: this, event: toolNext.value, invocationState }) + toolNext = await this._tracer.withSpanContext(toolSpan, () => toolGenerator.next()) + } + const result = toolNext.value + + if (!result) { + // Tool didn't return a result + toolResult = new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status: 'error', + content: [new TextBlock(`Tool '${toolUse.name}' did not return a result`)], + }) + } else { + toolResult = result + error = result.error + } + } catch (e) { + // Re-throw InterruptError to allow interrupt handling + if (e instanceof InterruptError) { + throw e + } + // Tool execution failed with error + error = normalizeError(e) + toolResult = new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status: 'error', + content: [new TextBlock(error.message)], + error, + }) + } + } + + // End tool span with the raw tool result — telemetry reflects what the + // tool actually returned, independent of AfterToolCallEvent mutations. + this._tracer.endToolCallSpan(toolSpan, { toolResult, ...(error && { error }) }) + + // End tool metrics tracking + this._meter.endToolCall({ + tool: toolUse, + duration: Date.now() - toolStartTime, + success: toolResult.status === 'success', + }) + + // Single point for AfterToolCallEvent + const afterToolCallEvent = new AfterToolCallEvent({ + agent: this, + toolUse, + tool: effectiveTool, + result: toolResult, + invocationState, + ...(error !== undefined && { error }), + }) + yield afterToolCallEvent + + if (afterToolCallEvent.retry) { + continue + } + + // Return the (possibly mutated) result so hook transformations propagate + // to ToolResultEvent and the conversation message the model will see. + return afterToolCallEvent.result + } + } + + /** + * Redacts the last message in the conversation history. + * Called when guardrails block user input and redaction is enabled. + * + * Follows the redaction strategy: + * - If the message contains at least one toolResult block, all toolResult blocks + * are kept with redacted content, and all other blocks are discarded. + * - Otherwise, the entire content is replaced with a single text block containing + * the redaction message. + * + * @param redactMessage - The redaction message to replace the content with + */ + private _redactLastMessage(redactMessage: string): void { + // Find and redact the last message + const lastIndex = this.messages.length - 1 + if (lastIndex >= 0) { + const lastMessage = this.messages[lastIndex] + if (lastMessage && lastMessage.role === 'user') { + // Collect only tool result blocks with redacted content + const redactedContent: ContentBlock[] = [] + for (const block of lastMessage.content) { + if (block.type === 'toolResultBlock') { + // Preserve tool result block structure, only redact its content + redactedContent.push( + new ToolResultBlock({ + toolUseId: block.toolUseId, + status: block.status, + content: [new TextBlock(redactMessage)], + }) + ) + } + } + + // If no tool result blocks were found, replace entire content with redaction message + if (redactedContent.length === 0) { + redactedContent.push(new TextBlock(redactMessage)) + } + + this.messages[lastIndex] = new Message({ + role: 'user', + content: redactedContent, + }) + } else if (lastMessage) { + // Unexpected state: redaction requested but last message is not from user + logger.warn( + `role=<${lastMessage.role}> | received input redaction but last message is not from user | redaction skipped` + ) + } + } + } + + /** + * Estimate the input token count for the next model call. + * + * Uses the token counting strategy: reads inputTokens + outputTokens + * from the last assistant message's metadata as a known baseline, then estimates + * only new messages added after it. Falls back to full estimation when no metadata + * is available (cold start or first call). + * + * @param streamOptions - The stream options containing system prompt and tool specs + * @returns Estimated input token count + */ + private async _estimateInputTokens(streamOptions: StreamOptions): Promise { + // Find the last assistant message with usage metadata + let lastAssistantIdx = -1 + for (let i = this.messages.length - 1; i >= 0; i--) { + if (this.messages[i]!.role === 'assistant' && this.messages[i]!.metadata?.usage) { + lastAssistantIdx = i + break + } + } + + let estimate: number + if (lastAssistantIdx >= 0) { + const usage = this.messages[lastAssistantIdx]!.metadata!.usage! + const knownBaseline = usage.inputTokens + usage.outputTokens + const newMessages = this.messages.slice(lastAssistantIdx + 1) + if (newMessages.length === 0) { + estimate = knownBaseline + } else { + // System prompt and tool spec tokens are already included in the baseline from the prior model call + estimate = knownBaseline + (await this.model.countTokens(newMessages)) + } + } else { + estimate = await this.model.countTokens(this.messages, { + ...(streamOptions.systemPrompt !== undefined && { systemPrompt: streamOptions.systemPrompt }), + ...(streamOptions.toolSpecs !== undefined && { toolSpecs: streamOptions.toolSpecs }), + }) + } + + return estimate + } + + /** + * Appends a message to the conversation history and fires MessageAddedEvent hooks. + * + * Used by {@link ToolCaller} (via the helper passed to `ToolCaller.create`) for + * direct tool calls that cannot yield events into the agent stream. This stays + * private — callers outside the agent should never directly mutate messages. + */ + private async _appendMessageAndFireHooks(message: Message, invocationState: InvocationState = {}): Promise { + this.messages.push(message) + await this._hooksRegistry.invokeCallbacks(new MessageAddedEvent({ agent: this, message, invocationState })) + } + + /** + * Appends a message to the conversation history and returns the event for yielding. + * + * @param message - The message to append + * @returns MessageAddedEvent to be yielded + */ + private _appendMessage(message: Message, invocationState: InvocationState): MessageAddedEvent { + this.messages.push(message) + return new MessageAddedEvent({ agent: this, message, invocationState }) + } +} + +const INVALID_TOOL_NAME_PLACEHOLDER = 'INVALID_TOOL_NAME' + +/** + * Replaces invalid tool-use names on assistant messages with `INVALID_TOOL_NAME` + * so providers that reject malformed names don't fail the whole request. + * Returns the input unchanged (same reference) when nothing needs replacing. + */ +function normalizeToolUseNames(messages: Message[]): Message[] { + let replaced = false + const next = messages.map((message) => { + if (!message || message.role !== 'assistant') return message + + let messageReplaced = false + const content = message.content.map((block) => { + if (block.type !== 'toolUseBlock') return block + if (isValidToolName(block.name)) return block + messageReplaced = true + logger.debug(`tool_name=<${block.name}> | replacing invalid tool name with ${INVALID_TOOL_NAME_PLACEHOLDER}`) + return new ToolUseBlock({ + name: INVALID_TOOL_NAME_PLACEHOLDER, + toolUseId: block.toolUseId, + input: block.input, + ...(block.reasoningSignature !== undefined && { reasoningSignature: block.reasoningSignature }), + }) + }) + + if (!messageReplaced) return message + replaced = true + return new Message({ + role: message.role, + content, + ...(message.metadata !== undefined && { metadata: message.metadata }), + }) + }) + + return replaced ? next : messages +} + +/** + * Recursively flattens nested arrays of tools into a single flat array. + * @param tools - Tools or nested arrays of tools + * @returns Flat array of tools and MCP clients + */ +function flattenTools(toolList: ToolList): { tools: Tool[]; mcpClients: McpClient[] } { + const tools: Tool[] = [] + const mcpClients: McpClient[] = [] + + for (const item of toolList) { + if (Array.isArray(item)) { + const { tools: nestedTools, mcpClients: nestedMcpClients } = flattenTools(item) + tools.push(...nestedTools) + mcpClients.push(...nestedMcpClients) + } else if (item instanceof Agent) { + tools.push(item.asTool()) + } else if (item instanceof McpClient) { + mcpClients.push(item) + } else { + tools.push(item) + } + } + + return { tools, mcpClients } +} diff --git a/strands-ts/src/agent/printer.ts b/strands-ts/src/agent/printer.ts new file mode 100644 index 0000000000..bad81cc1c8 --- /dev/null +++ b/strands-ts/src/agent/printer.ts @@ -0,0 +1,234 @@ +import type { AgentStreamEvent } from '../types/agent.js' +import type { + ModelStreamEvent, + ModelContentBlockDeltaEventData, + ModelContentBlockStartEventData, +} from '../models/streaming.js' +import type { BeforeToolCallEvent, BeforeToolsEvent, ToolResultEvent } from '../hooks/events.js' + +/** + * Creates a default appender function for the current environment. + * Uses process.stdout.write in Node.js and console.log in browsers. + * @returns Appender function that writes text to the output destination + */ +export function getDefaultAppender(): (text: string) => void { + // Check if we're in Node.js environment with stdout + if (typeof process !== 'undefined' && process.stdout?.write) { + return (text: string) => process.stdout.write(text) + } + // Fall back to console.log for browser environment + return (text: string) => console.log(text) +} + +/** + * Interface for printing agent activity to a destination. + * Implementations can output to stdout, console, HTML elements, etc. + */ +export interface Printer { + /** + * Write content to the output destination. + * @param content - The content to write + */ + write(content: string): void + + /** + * Process a streaming event from the agent. + * @param event - The event to process + */ + processEvent(event: AgentStreamEvent): void +} + +/** + * Default implementation of the Printer interface. + * Outputs text, reasoning, and tool execution activity to the configured appender. + */ +export class AgentPrinter implements Printer { + private readonly _appender: (text: string) => void + private _inReasoningBlock: boolean = false + private _toolCount: number = 0 + private _needReasoningIndent: boolean = false + + /** + * Creates a new AgentPrinter. + * @param appender - Function that writes text to the output destination + */ + constructor(appender: (text: string) => void) { + this._appender = appender + } + + /** + * Write content to the output destination. + * @param content - The content to write + */ + public write(content: string): void { + this._appender(content) + } + + /** + * Process a streaming event from the agent. + * Handles text deltas, reasoning content, and tool execution events. + * @param event - The event to process + */ + public processEvent(event: AgentStreamEvent): void { + switch (event.type) { + case 'modelStreamUpdateEvent': + this.handleModelStreamEvent(event.event) + break + + case 'beforeToolCallEvent': + this.handleBeforeToolCall(event) + break + + case 'beforeToolsEvent': + this.handleBeforeTools(event) + break + + case 'toolResultEvent': + this.handleToolResult(event) + break + + case 'agentResultEvent': + this.write('\n') + break + + // Ignore other event types + default: + break + } + } + + /** + * Handle raw model stream events unwrapped from ModelStreamUpdateEvent. + */ + private handleModelStreamEvent(event: ModelStreamEvent): void { + switch (event.type) { + case 'modelContentBlockDeltaEvent': + this.handleContentBlockDelta(event) + break + case 'modelContentBlockStartEvent': + this.handleContentBlockStart(event) + break + case 'modelContentBlockStopEvent': + this.handleContentBlockStop() + break + default: + break + } + } + + /** + * Handle content block delta events (text or reasoning). + */ + private handleContentBlockDelta(event: ModelContentBlockDeltaEventData): void { + const { delta } = event + + if (delta.type === 'textDelta') { + // Output text immediately + if (delta.text && delta.text.length > 0) { + this.write(delta.text) + } + } else if (delta.type === 'reasoningContentDelta') { + // Start reasoning block if not already in one + if (!this._inReasoningBlock) { + this._inReasoningBlock = true + this._needReasoningIndent = true + this.write('\n💭 Reasoning:\n') + } + + // Stream reasoning text with proper indentation + if (delta.text && delta.text.length > 0) { + this.writeReasoningText(delta.text) + } + } + // Ignore toolUseInputDelta and other delta types + } + + /** + * Write reasoning text with proper indentation after newlines. + */ + private writeReasoningText(text: string): void { + let output = '' + + for (let i = 0; i < text.length; i++) { + const char = text[i] + + // Add indentation if needed (at start or after newline) + if (this._needReasoningIndent && char !== '\n') { + output += ' ' + this._needReasoningIndent = false + } + + output += char + + // Mark that we need indentation after a newline + if (char === '\n') { + this._needReasoningIndent = true + } + } + + this.write(output) + } + + /** + * Handle content block start events. + * Prints a subtle preview during streaming; the definitive announcement + * (with numbering and status icon) comes in beforeToolCallEvent after hooks resolve. + */ + private handleContentBlockStart(event: ModelContentBlockStartEventData): void { + if (event.start?.type === 'toolUseStart') { + this.write(`\n ⏳ ${event.start.name}\n`) + } + } + + /** + * Handle content block stop events. + * Closes reasoning blocks if we were in one. + */ + private handleContentBlockStop(): void { + if (this._inReasoningBlock) { + // End reasoning block with a newline if we didn't just write one + if (!this._needReasoningIndent) { + this.write('\n') + } + this._inReasoningBlock = false + this._needReasoningIndent = false + } + } + + /** + * Handle before-tool-call events. + * Announces the tool after hooks have resolved, so denied tools get a + * distinct indicator instead of looking like they executed. + */ + private handleBeforeToolCall(event: BeforeToolCallEvent): void { + this._toolCount++ + if (event.cancel) { + this.write(`\n🚫 Tool #${this._toolCount}: ${event.toolUse.name} (denied)\n`) + } else { + this.write(`\n🔧 Tool #${this._toolCount}: ${event.toolUse.name}\n`) + } + } + + /** + * Handle before-tools events. + * When all tools are batch-cancelled, prints a notice since no individual + * BeforeToolCallEvent will fire. + */ + private handleBeforeTools(event: BeforeToolsEvent): void { + if (event.cancel) { + this.write('\n🚫 All tools denied\n') + } + } + + /** + * Handle tool result events. + * Outputs completion status. + */ + private handleToolResult(event: ToolResultEvent): void { + if (event.result.status === 'success') { + this.write('✓ Tool completed\n') + } else if (event.result.status === 'error') { + this.write('✗ Tool failed\n') + } + } +} diff --git a/strands-ts/src/agent/snapshot.ts b/strands-ts/src/agent/snapshot.ts new file mode 100644 index 0000000000..18206d2393 --- /dev/null +++ b/strands-ts/src/agent/snapshot.ts @@ -0,0 +1,233 @@ +/** + * Snapshot helpers for agent state capture and restoration. + * + * These functions provide the shared implementation for {@link LocalAgent.takeSnapshot} + * and {@link LocalAgent.loadSnapshot}. Since all `LocalAgent` implementations share the + * same snapshot logic, these helpers avoid duplication. The canonical public API is + * `agent.takeSnapshot()` / `agent.loadSnapshot()` — prefer calling those directly. + */ + +import type { JSONValue } from '../types/json.js' +import type { MessageData, SystemPromptData } from '../types/messages.js' +import { Message, systemPromptFromData, systemPromptToData } from '../types/messages.js' +import { loadStateSerializable, serializeStateSerializable } from '../types/serializable.js' +import type { LocalAgent } from '../types/agent.js' +import { SNAPSHOT_SCHEMA_VERSION } from '../types/snapshot.js' +import type { Snapshot } from '../types/snapshot.js' +import { InterruptState, type InterruptStateData } from '../interrupt.js' + +/** + * All available fields that can be included in a snapshot. + */ +export const ALL_SNAPSHOT_FIELDS = ['messages', 'state', 'systemPrompt', 'modelState', 'interrupts'] as const + +/** + * Strongly typed preset definitions for snapshot field selection. + * This object allows easy evolution of presets and type-safe access. + */ +export const SNAPSHOT_PRESETS = { + session: ['messages', 'state', 'systemPrompt', 'modelState', 'interrupts'] as const, +} as const + +/** + * Preset name for snapshot field selection. + */ +export type SnapshotPreset = keyof typeof SNAPSHOT_PRESETS + +/** + * Valid snapshot field names. + */ +export type SnapshotField = (typeof ALL_SNAPSHOT_FIELDS)[number] + +/** + * Creates an ISO 8601 timestamp string. + * + * @returns Current timestamp in ISO 8601 format + */ +export function createTimestamp(): string { + return new Date().toISOString() +} + +/** + * Options for taking a snapshot of agent state. + */ +export type TakeSnapshotOptions = { + /** + * Preset to use as the starting set of fields. + * If not specified, starts with an empty set (unless include is specified). + */ + preset?: SnapshotPreset + /** + * Fields to add to the snapshot. + * These are added to the preset fields (if any). + */ + include?: SnapshotField[] + /** + * Fields to exclude from the snapshot. + * Applied after preset and include to filter out specific fields. + */ + exclude?: SnapshotField[] + /** + * Application-owned data to store in the snapshot. + * Strands does not read or modify this data. + */ + appData?: Record +} + +/** + * Shared implementation for {@link LocalAgent.takeSnapshot}. + * Prefer calling `agent.takeSnapshot(options)` directly. + * + * @param agent - The agent to snapshot + * @param options - Snapshot options + * @returns A snapshot of the agent's state + */ +export function takeSnapshot(agent: LocalAgent, options: TakeSnapshotOptions): Snapshot { + const fields = resolveSnapshotFields(options) + + const data: Record = {} + + if (fields.has('messages')) { + data.messages = agent.messages.map((msg) => msg.toJSON()) as unknown as JSONValue + } + + if (fields.has('state')) { + data.state = serializeStateSerializable(agent.appState) + } + + if (fields.has('systemPrompt')) { + data.systemPrompt = agent.systemPrompt !== undefined ? (systemPromptToData(agent.systemPrompt) as JSONValue) : null + } + + if (fields.has('modelState')) { + data.modelState = serializeStateSerializable(agent.modelState) + } + + if (fields.has('interrupts')) { + const interruptState = (agent as unknown as { _interruptState?: InterruptState })._interruptState + data.interrupts = interruptState ? (interruptState.toJSON() as unknown as JSONValue) : null + } + + return { + scope: 'agent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: createTimestamp(), + data, + appData: options.appData ?? {}, + } +} + +/** + * Shared implementation for {@link LocalAgent.loadSnapshot}. + * Prefer calling `agent.loadSnapshot(snapshot)` directly. + * + * @param agent - The agent to restore state into + * @param snapshot - The snapshot to load + */ +export function loadSnapshot(agent: LocalAgent, snapshot: Snapshot): void { + if (snapshot.scope !== 'agent') { + throw new Error(`Expected snapshot scope 'agent', got '${snapshot.scope}'`) + } + if (snapshot.schemaVersion !== SNAPSHOT_SCHEMA_VERSION) { + throw new Error( + `Unsupported snapshot schema version: ${snapshot.schemaVersion}. Current version: ${SNAPSHOT_SCHEMA_VERSION}` + ) + } + + if ('messages' in snapshot.data) { + const messages = snapshot.data.messages + agent.messages.length = 0 + for (const msgData of messages as unknown as MessageData[]) { + agent.messages.push(Message.fromJSON(msgData)) + } + } + + if ('state' in snapshot.data) { + loadStateSerializable(agent.appState, snapshot.data.state) + } + + // Use key-presence check to distinguish "field absent" (leave unchanged) from + // "field present as null" (agent had no system prompt — clear it). + if ('systemPrompt' in snapshot.data) { + const systemPrompt = snapshot.data.systemPrompt + if (systemPrompt !== null) { + agent.systemPrompt = systemPromptFromData(systemPrompt as SystemPromptData) + } else { + delete agent.systemPrompt + } + } + + if ('modelState' in snapshot.data) { + loadStateSerializable(agent.modelState, snapshot.data.modelState) + } + + if ('interrupts' in snapshot.data) { + const interruptStateData = snapshot.data.interrupts + if (interruptStateData !== null) { + const agentRecord = agent as unknown as { _interruptState: InterruptState } + agentRecord._interruptState = InterruptState.fromJSON(interruptStateData as unknown as InterruptStateData) + } + } +} + +/** + * Resolves snapshot fields based on preset/include/exclude parameters. + * + * Order of operations: + * 1. Start with preset fields (if specified) + * 2. Add include fields + * 3. Remove exclude fields + * + * @param options - Snapshot options containing preset, include, and exclude fields + * @returns Set of resolved field names + * @throws Error if no fields would be included + */ +export function resolveSnapshotFields(options: TakeSnapshotOptions = {}): Set { + const { preset, include, exclude } = options + let fields: Set + + // Start with preset fields or empty set + if (preset !== undefined) { + if (!(preset in SNAPSHOT_PRESETS)) { + throw new Error(`Invalid preset: ${preset}. Valid presets are: ${Object.keys(SNAPSHOT_PRESETS).join(', ')}`) + } + fields = new Set(SNAPSHOT_PRESETS[preset]) + } else { + fields = new Set() + } + + // Add include fields + if (include !== undefined) { + validateSnapshotFields(include) + for (const field of include) { + fields.add(field) + } + } + + // Remove exclude fields (no error if field wasn't included) + if (exclude !== undefined) { + validateSnapshotFields(exclude) + for (const field of exclude) { + fields.delete(field) + } + } + + // Must have at least one field + if (fields.size === 0) { + throw new Error('No fields to include in snapshot. Specify a preset or include fields.') + } + + return fields +} + +/** + * Validates that all field names are valid snapshot fields. + */ +function validateSnapshotFields(fields: string[]): void { + const validFields = new Set(ALL_SNAPSHOT_FIELDS) + for (const field of fields) { + if (!validFields.has(field)) { + throw new Error(`Invalid snapshot field: ${field}. Valid fields are: ${ALL_SNAPSHOT_FIELDS.join(', ')}`) + } + } +} diff --git a/strands-ts/src/agent/tool-caller.ts b/strands-ts/src/agent/tool-caller.ts new file mode 100644 index 0000000000..a2d750fe2b --- /dev/null +++ b/strands-ts/src/agent/tool-caller.ts @@ -0,0 +1,294 @@ +/** + * Direct tool calling support through agent.tool accessor. + * + * Enables method-style tool invocation without model inference: + * ```typescript + * const agent = new Agent({ tools: [myTool] }) + * const result = await agent.tool.calculator!.invoke({ a: 5, b: 3 }) + * ``` + */ + +import type { JSONValue } from '../types/json.js' +import type { ToolResultBlock } from '../types/messages.js' +import { Message } from '../types/messages.js' +import { TextBlock, ToolUseBlock } from '../types/messages.js' +import type { InvocationState } from '../types/agent.js' +import type { Tool, ToolContext } from '../tools/tool.js' +import { ToolStreamEvent } from '../tools/tool.js' +import type { ToolUse } from '../tools/types.js' +import type { Agent } from './agent.js' +import { ConcurrentInvocationError } from '../errors.js' + +/** + * Options for direct tool call execution. + */ +export interface DirectToolCallOptions { + /** + * Whether to record this tool call in the agent's message history. + * Defaults to `true`. Set to `false` to execute the tool without + * affecting conversation context. + */ + recordDirectToolCall?: boolean +} + +/** + * A handle to a specific tool, providing `.invoke()` and `.stream()` methods. + * + * Returned by the Proxy get trap when accessing `agent.tool.toolName`. + * This aligns with the agent-level `agent.invoke()` / `agent.stream()` pattern. + */ +export interface ToolHandle { + /** + * Invoke the tool and return the final result. + * + * @param input - The input parameters for the tool + * @param options - Optional configuration for this call + * @returns The tool result + */ + invoke: (input?: JSONValue, options?: DirectToolCallOptions) => Promise + + /** + * Stream the tool execution, yielding intermediate events and returning the final result. + * + * @param input - The input parameters for the tool + * @param options - Optional configuration for this call + * @returns Async generator that yields ToolStreamEvents and returns ToolResultBlock + */ + stream: ( + input?: JSONValue, + options?: DirectToolCallOptions + ) => AsyncGenerator +} + +/** + * The public type of the tool caller proxy. + * Provides dynamic property access where each property is a {@link ToolHandle} + * with `.invoke()` and `.stream()` methods. + */ +export type ToolCallerProxy = Record + +/** + * Helper passed in from Agent for appending messages and firing MessageAddedEvent hooks. + * + * Defined here (not in agent.ts) so that the message-mutation capability stays + * encapsulated — only the Agent knows how to mutate messages safely, and it + * passes a bound helper into ToolCaller. ToolCaller never gets direct access + * to `agent.messages` or the hooks registry. + */ +export type AppendMessageFn = (message: Message, invocationState?: InvocationState) => Promise + +/** + * Provides direct tool calling through the agent. + * + * Enables programmatic tool invocation without model inference via + * `agent.tool.toolName.invoke(input)` or `agent.tool.toolName.stream(input)`. + * Tools are called directly, bypassing the model loop, and results are optionally + * recorded in message history for context continuity. + * + * Supports underscore-to-hyphen and case-insensitive name normalization + * via {@link ToolRegistry.resolve}. + * + * @example + * ```typescript + * const agent = new Agent({ tools: [calculatorTool] }) + * + * // Invoke and get the result + * const result = await agent.tool.calculator!.invoke({ operation: 'add', a: 5, b: 3 }) + * console.log(result.status) // 'success' + * + * // Stream intermediate events + * for await (const event of agent.tool.calculator!.stream({ operation: 'add', a: 5, b: 3 })) { + * console.log('progress:', event) + * } + * ``` + * + * @internal This class is not intended for direct instantiation by users. + */ +export class ToolCaller { + private readonly _agent: Agent + private readonly _appendMessage: AppendMessageFn + + /** + * Creates a ToolCaller proxy for the given agent. + * + * Encapsulates the Proxy cast so callers don't need to handle the + * implementation detail that the constructor returns a Proxy, not + * a plain ToolCaller instance. + * + * @param agent - The owning agent instance + * @param appendMessage - Helper provided by the agent to append messages and fire hooks. + * Passed in (rather than calling a public agent method) so message mutation stays + * encapsulated within the agent. + */ + static create(agent: Agent, appendMessage: AppendMessageFn): ToolCallerProxy { + return new ToolCaller(agent, appendMessage) as unknown as ToolCallerProxy + } + + private constructor(agent: Agent, appendMessage: AppendMessageFn) { + this._agent = agent + this._appendMessage = appendMessage + + // Return a Proxy that intercepts property access to resolve tool names + return new Proxy(this, { + get(target: ToolCaller, prop: string | symbol, receiver: unknown): ToolHandle | unknown { + // Pass through symbol properties (Symbol.toPrimitive, Symbol.iterator, etc.) + // Uses Reflect.get for proper receiver forwarding. + if (typeof prop === 'symbol') { + return Reflect.get(target, prop, receiver) + } + + // Prevent accidental thenable behavior — if a user writes `await agent.tool` + // the JS runtime checks for `.then`. Without this guard, the Proxy would return + // a ToolHandle for a non-existent tool named "then", which is confusing. + // Note: this means a tool literally named "then" cannot be accessed via this proxy. + if (prop === 'then') { + return undefined + } + + // Return a ToolHandle with .invoke() and .stream() for the named tool. + // We intentionally do NOT fall through to `prop in target` here — that would + // cause tool names that collide with inherited Object properties (e.g., + // 'constructor', 'toString', 'valueOf') to return the wrong value. + return target._createToolHandle(prop) + }, + }) + } + + /** + * Creates a ToolHandle for the given tool name. + */ + private _createToolHandle(name: string): ToolHandle { + return { + invoke: (input?: JSONValue, options?: DirectToolCallOptions): Promise => { + return this._callTool(name, input ?? {}, options) + }, + stream: ( + input?: JSONValue, + options?: DirectToolCallOptions + ): AsyncGenerator => { + return this._streamTool(name, input ?? {}, options) + }, + } + } + + /** + * Executes a tool by name with the given input, consuming the full stream and returning the result. + * + * @param name - The tool name (supports underscore-to-hyphen and case-insensitive resolution) + * @param input - The input parameters for the tool + * @param options - Optional configuration for this call + * @returns The tool result + */ + private async _callTool(name: string, input: JSONValue, options?: DirectToolCallOptions): Promise { + const gen = this._streamTool(name, input, options) + let result = await gen.next() + while (!result.done) { + result = await gen.next() + } + return result.value + } + + /** + * Streams a tool execution by name, yielding intermediate events. + * + * @param name - The tool name + * @param input - The input parameters for the tool + * @param options - Optional configuration for this call + * @returns Async generator that yields ToolStreamEvents and returns ToolResultBlock + */ + private async *_streamTool( + name: string, + input: JSONValue, + options?: DirectToolCallOptions + ): AsyncGenerator { + const shouldRecord = options?.recordDirectToolCall ?? true + + // If recording, check that the agent is not currently invoking + if (shouldRecord && this._agent.isInvoking) { + throw new ConcurrentInvocationError( + 'Direct tool call cannot be made while the agent is in the middle of an invocation. ' + + 'Set recordDirectToolCall: false to allow direct tool calls during agent invocation.' + ) + } + + // Resolve the tool via the registry's normalization (exact → hyphen → case-insensitive) + const tool = this._agent.toolRegistry.resolve(name) + + // Generate unique tool use ID + const toolUseId = `tooluse_${globalThis.crypto.randomUUID()}` + const toolUse: ToolUse = { + toolUseId, + name: tool.name, + input, + } + + // Create tool context + const toolContext: ToolContext = { + toolUse, + agent: this._agent, + invocationState: {}, + interrupt: (): never => { + throw new Error('Interrupts are not supported in direct tool calls') + }, + } + + // Execute the tool, yielding stream events + const toolResult = yield* this._executeTool(tool, toolContext) + + // Record in message history if configured + if (shouldRecord) { + await this._recordToolExecution(toolUse, toolResult) + } + + return toolResult + } + + /** + * Executes a tool's stream generator, yielding events and returning the final result. + */ + private async *_executeTool( + tool: Tool, + toolContext: ToolContext + ): AsyncGenerator { + const generator = tool.stream(toolContext) + let result = await generator.next() + while (!result.done) { + yield result.value + result = await generator.next() + } + return result.value + } + + /** + * Records a tool execution in the agent's message history and fires MessageAddedEvent hooks. + * + * Creates a sequence of 3 messages that represent the tool execution: + * 1. An assistant message with the ToolUseBlock (what was called and with what input) + * 2. A user message with the ToolResultBlock (tool output) + * 3. An assistant message acknowledging the result + * + * Each message fires a {@link MessageAddedEvent} so that hooks registered via + * `agent.addHook(MessageAddedEvent, ...)` are notified of direct tool call messages. + */ + private async _recordToolExecution(toolUse: ToolUse, toolResult: ToolResultBlock): Promise { + const toolUseBlock = new ToolUseBlock({ + toolUseId: toolUse.toolUseId, + name: toolUse.name, + input: toolUse.input, + }) + + const toolUseMsg = new Message({ role: 'assistant', content: [toolUseBlock] }) + const toolResultMsg = new Message({ role: 'user', content: [toolResult] }) + const assistantMsg = new Message({ + role: 'assistant', + content: [new TextBlock(`agent.tool.${toolUse.name} was called.`)], + }) + + // Append messages and fire MessageAddedEvent hooks for each, using the + // helper provided by Agent. This keeps message mutation encapsulated in + // the agent — ToolCaller never touches `agent.messages` directly. + await this._appendMessage(toolUseMsg) + await this._appendMessage(toolResultMsg) + await this._appendMessage(assistantMsg) + } +} diff --git a/strands-ts/src/conversation-manager/__tests__/conversation-manager.test.ts b/strands-ts/src/conversation-manager/__tests__/conversation-manager.test.ts new file mode 100644 index 0000000000..e5106e0d95 --- /dev/null +++ b/strands-ts/src/conversation-manager/__tests__/conversation-manager.test.ts @@ -0,0 +1,389 @@ +import { describe, it, expect, vi } from 'vitest' +import { + ConversationManager, + type ConversationManagerReduceOptions, + type ConversationManagerOptions, +} from '../conversation-manager.js' +import { NullConversationManager } from '../null-conversation-manager.js' +import { Agent } from '../../agent/agent.js' +import { Message, TextBlock } from '../../index.js' +import { AfterModelCallEvent, BeforeModelCallEvent } from '../../hooks/events.js' +import { ContextWindowOverflowError } from '../../errors.js' +import { createMockAgent, invokeTrackedHook } from '../../__fixtures__/agent-helpers.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import type { BaseModelConfig } from '../../models/model.js' +import { warnOnce } from '../../logging/warn-once.js' + +vi.mock('../../logging/warn-once.js', () => ({ + warnOnce: vi.fn(), +})) + +class TestConversationManager extends ConversationManager { + readonly name = 'test:conversation-manager' + reduceCallCount = 0 + shouldReduce = true + + constructor(options?: ConversationManagerOptions) { + super(options) + } + + reduce({ agent }: ConversationManagerReduceOptions): boolean { + this.reduceCallCount++ + if (!this.shouldReduce) return false + agent.messages.splice(0, 1) + return true + } +} + +class ThresholdTestManager extends ConversationManager { + readonly name = 'test:threshold-manager' + reduceCallCount = 0 + shouldReduce = true + + constructor(options?: ConversationManagerOptions) { + super(options) + } + + reduce({ agent }: ConversationManagerReduceOptions): boolean { + this.reduceCallCount++ + if (!this.shouldReduce) return false + agent.messages.splice(0, 1) + return true + } +} + +describe('ConversationManager', () => { + describe('initAgent', () => { + it('registers both AfterModelCallEvent and BeforeModelCallEvent hooks', () => { + const manager = new TestConversationManager() + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + // Always registers both hooks now + expect(mockAgent.trackedHooks).toHaveLength(2) + expect(mockAgent.trackedHooks[0]!.eventType).toBe(AfterModelCallEvent) + expect(mockAgent.trackedHooks[1]!.eventType).toBe(BeforeModelCallEvent) + }) + + it('calls reduce and sets retry=true on ContextWindowOverflowError when reduce returns true', async () => { + const manager = new TestConversationManager() + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + const mockAgent = createMockAgent({ messages }) + manager.initAgent(mockAgent) + + const error = new ContextWindowOverflowError('overflow') + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error, + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(manager.reduceCallCount).toBe(1) + expect(event.retry).toBe(true) + expect(mockAgent.messages).toHaveLength(1) + }) + + it('does not set retry when reduce returns false', async () => { + const manager = new TestConversationManager() + manager.shouldReduce = false + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + const error = new ContextWindowOverflowError('overflow') + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error, + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(manager.reduceCallCount).toBe(1) + expect(event.retry).toBeUndefined() + }) + + it('does not call reduce for non-overflow errors', async () => { + const manager = new TestConversationManager() + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + const error = new Error('some other error') + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error, + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(manager.reduceCallCount).toBe(0) + expect(event.retry).toBeUndefined() + }) + + it('passes error to reduce when called due to overflow', async () => { + const receivedArgs: ConversationManagerReduceOptions[] = [] + class CapturingManager extends ConversationManager { + readonly name = 'test:capturing' + reduce(args: ConversationManagerReduceOptions): boolean { + receivedArgs.push(args) + return false + } + } + + const manager = new CapturingManager() + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + const error = new ContextWindowOverflowError('overflow') + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error, + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(receivedArgs).toHaveLength(1) + expect(receivedArgs[0]!.error).toBe(error) + expect(receivedArgs[0]!.agent).toBe(mockAgent) + }) + }) + + describe('proactiveCompression', () => { + const mockModel = { getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig } as any + + it('always registers a BeforeModelCallEvent hook regardless of proactiveCompression setting', () => { + const manager = new TestConversationManager() + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + // Both hooks always registered + expect(mockAgent.trackedHooks).toHaveLength(2) + expect(mockAgent.trackedHooks[0]!.eventType).toBe(AfterModelCallEvent) + expect(mockAgent.trackedHooks[1]!.eventType).toBe(BeforeModelCallEvent) + }) + + it('BeforeModelCallEvent handler is a no-op when proactiveCompression is not set', async () => { + const manager = new ThresholdTestManager() + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 900, // Would exceed any threshold + }) + await invokeTrackedHook(mockAgent, event) + + expect(manager.reduceCallCount).toBe(0) + }) + + it('uses default threshold of 0.7 when proactiveCompression is true', async () => { + const manager = new ThresholdTestManager({ proactiveCompression: true }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + const mockAgent = createMockAgent({ messages }) + manager.initAgent(mockAgent) + + // 650/1000 = 0.65 < 0.7 — should NOT trigger + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 650, + }) + await invokeTrackedHook(mockAgent, event) + expect(manager.reduceCallCount).toBe(0) + + // 800/1000 = 0.8 >= 0.7 — should trigger + const event2 = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 800, + }) + await invokeTrackedHook(mockAgent, event2) + expect(manager.reduceCallCount).toBe(1) + }) + + it('calls reduce without error when projected tokens exceed custom threshold', async () => { + const receivedArgs: ConversationManagerReduceOptions[] = [] + class CapturingManager extends ConversationManager { + readonly name = 'test:capturing-threshold' + reduce(args: ConversationManagerReduceOptions): boolean { + receivedArgs.push(args) + args.agent.messages.splice(0, 1) + return true + } + } + + const manager = new CapturingManager({ proactiveCompression: { compressionThreshold: 0.5 } }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + const mockAgent = createMockAgent({ messages }) + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 600, // 600/1000 = 0.6 >= 0.5 + }) + await invokeTrackedHook(mockAgent, event) + + expect(receivedArgs).toHaveLength(1) + expect(receivedArgs[0]!.error).toBeUndefined() + expect(receivedArgs[0]!.model).toBe(mockModel) + expect(receivedArgs[0]!.agent).toBe(mockAgent) + }) + + it('does not call reduce when below threshold', async () => { + const manager = new ThresholdTestManager({ proactiveCompression: { compressionThreshold: 0.7 } }) + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 500, // 500/1000 = 0.5 < 0.7 + }) + await invokeTrackedHook(mockAgent, event) + + expect(manager.reduceCallCount).toBe(0) + }) + + it('does not call reduce when projectedInputTokens is undefined', async () => { + const manager = new ThresholdTestManager({ proactiveCompression: true }) + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + expect(manager.reduceCallCount).toBe(0) + }) + + it('uses 200k default when contextWindowLimit is undefined and logs warning', async () => { + const manager = new ThresholdTestManager({ proactiveCompression: { compressionThreshold: 0.7 } }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + const mockAgent = createMockAgent({ messages }) + manager.initAgent(mockAgent) + + const modelWithoutLimit = { getConfig: () => ({}) as BaseModelConfig } as any + // 150000/200000 = 0.75 >= 0.7 — should trigger with the 200k default + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: modelWithoutLimit, + invocationState: {}, + projectedInputTokens: 150000, + }) + await invokeTrackedHook(mockAgent, event) + + expect(manager.reduceCallCount).toBe(1) + expect(warnOnce).toHaveBeenCalledWith( + expect.anything(), + expect.stringContaining('contextWindowLimit is not set on the model, using default of 200000') + ) + }) + + it('does not trigger with 200k default when below threshold', async () => { + const manager = new ThresholdTestManager({ proactiveCompression: { compressionThreshold: 0.7 } }) + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + const modelWithoutLimit = { getConfig: () => ({}) as BaseModelConfig } as any + // 100000/200000 = 0.5 < 0.7 — should NOT trigger + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: modelWithoutLimit, + invocationState: {}, + projectedInputTokens: 100000, + }) + await invokeTrackedHook(mockAgent, event) + + expect(manager.reduceCallCount).toBe(0) + }) + + it('swallows errors from proactive reduce and continues', async () => { + class ThrowingManager extends ConversationManager { + readonly name = 'test:throwing' + reduce({ error }: ConversationManagerReduceOptions): boolean { + if (!error) { + throw new Error('proactive compression exploded') + } + return false + } + } + + const manager = new ThrowingManager({ proactiveCompression: true }) + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 800, + }) + + // Should not throw — error is swallowed + await expect(invokeTrackedHook(mockAgent, event)).resolves.toBeUndefined() + }) + + it('throws on compressionThreshold <= 0', () => { + expect(() => new ThresholdTestManager({ proactiveCompression: { compressionThreshold: 0 } })).toThrow( + 'must be between 0 (exclusive) and 1 (inclusive)' + ) + expect(() => new ThresholdTestManager({ proactiveCompression: { compressionThreshold: -1 } })).toThrow( + 'must be between 0 (exclusive) and 1 (inclusive)' + ) + }) + + it('throws on compressionThreshold > 1', () => { + expect(() => new ThresholdTestManager({ proactiveCompression: { compressionThreshold: 1.5 } })).toThrow( + 'must be between 0 (exclusive) and 1 (inclusive)' + ) + }) + + it('accepts compressionThreshold of exactly 1', () => { + expect(() => new ThresholdTestManager({ proactiveCompression: { compressionThreshold: 1 } })).not.toThrow() + }) + }) +}) + +describe('overflow propagation', () => { + it('propagates ContextWindowOverflowError out of the agent loop when reduce returns false', async () => { + const model = new MockMessageModel() + model.addTurn(new ContextWindowOverflowError('context window exceeded')) + + const agent = new Agent({ + model, + conversationManager: new NullConversationManager(), + printer: false, + }) + + await expect(agent.invoke('hello')).rejects.toThrow(ContextWindowOverflowError) + }) +}) diff --git a/strands-ts/src/conversation-manager/__tests__/null-conversation-manager.test.ts b/strands-ts/src/conversation-manager/__tests__/null-conversation-manager.test.ts new file mode 100644 index 0000000000..86b3143231 --- /dev/null +++ b/strands-ts/src/conversation-manager/__tests__/null-conversation-manager.test.ts @@ -0,0 +1,72 @@ +import { describe, it, expect } from 'vitest' +import { NullConversationManager } from '../null-conversation-manager.js' +import { Message, TextBlock } from '../../index.js' +import { AfterModelCallEvent, BeforeModelCallEvent } from '../../hooks/events.js' +import { ContextWindowOverflowError } from '../../errors.js' +import { createMockAgent, invokeTrackedHook } from '../../__fixtures__/agent-helpers.js' + +describe('NullConversationManager', () => { + describe('behavior', () => { + it('does not modify conversation history on overflow', async () => { + const manager = new NullConversationManager() + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Hello')] }), + new Message({ role: 'assistant', content: [new TextBlock('Hi there')] }), + ] + const mockAgent = createMockAgent({ messages }) + manager.initAgent(mockAgent) + + const error = new ContextWindowOverflowError('Context overflow') + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error, + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + // Messages should be unchanged — NullConversationManager never reduces + expect(mockAgent.messages).toHaveLength(2) + expect(mockAgent.messages[0]!.content[0]).toEqual({ type: 'textBlock', text: 'Hello' }) + expect(mockAgent.messages[1]!.content[0]).toEqual({ type: 'textBlock', text: 'Hi there' }) + }) + + it('does not set retry on context overflow', async () => { + const manager = new NullConversationManager() + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + const error = new ContextWindowOverflowError('Context overflow') + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error, + invocationState: {}, + }) + await invokeTrackedHook(mockAgent, event) + + // reduce() returns false, so retry should not be set + expect(event.retry).toBeUndefined() + }) + + it('registers both the overflow recovery and proactive compression hooks', () => { + const manager = new NullConversationManager() + const mockAgent = createMockAgent() + manager.initAgent(mockAgent) + + // Base class always registers both hooks + expect(mockAgent.trackedHooks).toHaveLength(2) + expect(mockAgent.trackedHooks[0]!.eventType).toBe(AfterModelCallEvent) + expect(mockAgent.trackedHooks[1]!.eventType).toBe(BeforeModelCallEvent) + }) + }) + + describe('name', () => { + it('returns the plugin name', () => { + const manager = new NullConversationManager() + expect(manager.name).toBe('strands:null-conversation-manager') + }) + }) +}) diff --git a/strands-ts/src/conversation-manager/__tests__/sliding-window-conversation-manager.test.ts b/strands-ts/src/conversation-manager/__tests__/sliding-window-conversation-manager.test.ts new file mode 100644 index 0000000000..836f2c66fd --- /dev/null +++ b/strands-ts/src/conversation-manager/__tests__/sliding-window-conversation-manager.test.ts @@ -0,0 +1,1214 @@ +import { describe, it, expect, vi } from 'vitest' +import { SlidingWindowConversationManager } from '../sliding-window-conversation-manager.js' +import { + ContextWindowOverflowError, + DocumentBlock, + ImageBlock, + JsonBlock, + Message, + TextBlock, + ToolUseBlock, + ToolResultBlock, + VideoBlock, + type Model, +} from '../../index.js' +import { AfterInvocationEvent, AfterModelCallEvent, BeforeModelCallEvent } from '../../hooks/events.js' +import { createMockAgent, invokeTrackedHook } from '../../__fixtures__/agent-helpers.js' +import type { Agent } from '../../agent/agent.js' +import type { BaseModelConfig } from '../../models/model.js' + +async function triggerSlidingWindow(manager: SlidingWindowConversationManager, agent: Agent): Promise { + const pluginAgent = createMockAgent() + manager.initAgent(pluginAgent) + await invokeTrackedHook(pluginAgent, new AfterInvocationEvent({ agent, invocationState: {} })) +} + +// Helper to trigger context overflow handling through hooks +async function triggerContextOverflow( + manager: SlidingWindowConversationManager, + agent: Agent, + error: Error +): Promise<{ retry?: boolean }> { + const pluginAgent = createMockAgent() + manager.initAgent(pluginAgent) + const event = new AfterModelCallEvent({ agent, model: {} as any, attemptCount: 1, error, invocationState: {} }) + await invokeTrackedHook(pluginAgent, event) + return event +} + +describe('SlidingWindowConversationManager', () => { + describe('constructor', () => { + it('sets default windowSize to 40', () => { + const manager = new SlidingWindowConversationManager() + // Access through type assertion since these are private + expect((manager as any)._windowSize).toBe(40) + }) + + it('sets default shouldTruncateResults to true', () => { + const manager = new SlidingWindowConversationManager() + expect((manager as any)._shouldTruncateResults).toBe(true) + }) + + it('accepts custom windowSize', () => { + const manager = new SlidingWindowConversationManager({ windowSize: 20 }) + expect((manager as any)._windowSize).toBe(20) + }) + + it('accepts custom shouldTruncateResults', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: false }) + expect((manager as any)._shouldTruncateResults).toBe(false) + }) + }) + + describe('reduce', () => { + it('returns true when tool results are truncated even though message count is unchanged', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('x'.repeat(500))], + }), + ], + }), + ] + + const result = manager.reduce({ + agent: createMockAgent({ messages }), + model: {} as Model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(result).toBe(true) + expect(messages).toHaveLength(1) // length unchanged, but truncation occurred + }) + + it('returns true when messages are trimmed', () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: false }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), + ] + + const result = manager.reduce({ + agent: createMockAgent({ messages }), + model: {} as Model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(result).toBe(true) + expect(messages).toHaveLength(2) + }) + }) + + describe('applyManagement', () => { + it('skips reduction when message count is less than window size', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 10 }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerSlidingWindow(manager, mockAgent) + + expect(mockAgent.messages).toHaveLength(2) + }) + + it('skips reduction when message count equals window size', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2 }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerSlidingWindow(manager, mockAgent) + + expect(mockAgent.messages).toHaveLength(2) + }) + + it('removes all messages when windowSize is 0', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 0 }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerSlidingWindow(manager, mockAgent) + + expect(mockAgent.messages).toHaveLength(0) + }) + + it('calls reduceContext when message count exceeds window size', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2 }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerSlidingWindow(manager, mockAgent) + + // Should have trimmed; first message must be user + expect(mockAgent.messages).toHaveLength(2) + expect(mockAgent.messages[0]!.role).toBe('user') + }) + }) + + describe('reduceContext - tool result truncation', () => { + it('partially truncates large tool results preserving first and last 200 chars', async () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const middle = 'MIDDLE_CONTENT_TO_REMOVE'.repeat(10) // 240 chars, safely above MIN_TRUNCATION_GAIN + const original = 'A'.repeat(200) + middle + 'B'.repeat(200) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock(original)], + }), + ], + }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + const expectedText = `${'A'.repeat(200)}\n\n${'B'.repeat(200)}` + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock(expectedText)], + }) + ) + }) + + it('leaves small tool results unchanged', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('Small result')], + }), + ], + }), + ] + + const result = (manager as any)._truncateToolResults(messages, 0) + expect(result).toBe(false) + }) + + it('finds oldest message with tool results', async () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const firstOriginal = 'F'.repeat(500) + const secondOriginal = 'S'.repeat(500) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock(firstOriginal)], + }), + ], + }), + new Message({ role: 'assistant', content: [new TextBlock('Response')] }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-2', + status: 'success', + content: [new TextBlock(secondOriginal)], + }), + ], + }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // Oldest tool-result message is truncated; newer one is untouched. + const expectedTruncated = `${'F'.repeat(200)}\n\n${'F'.repeat(200)}` + expect(messages[1]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock(expectedTruncated)], + }) + ) + expect(messages[3]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-2', + status: 'success', + content: [new TextBlock(secondOriginal)], + }) + ) + }) + + it('returns after successful truncation without trimming messages', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: true }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('L'.repeat(500))], + }), + ], + }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // Should not have removed any messages, only truncated tool result + expect(mockAgent.messages).toHaveLength(3) + }) + + it('skips truncation when shouldTruncateResults is false', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 3, shouldTruncateResults: false }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'tool1', toolUseId: 'tool-1', input: {} })], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('L'.repeat(500))], + }), + ], + }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // Should have trimmed messages instead of truncating tool result + expect(mockAgent.messages).toHaveLength(3) + expect(mockAgent.messages[0]!.role).toBe('user') + + // Tool result should not be truncated + expect(mockAgent.messages[2]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('L'.repeat(500))], + }) + ) + }) + + it('does not re-truncate already-truncated results', async () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + // Produced by an earlier run: 200 chars + marker + 200 chars = well under the 450-char + // threshold below which truncation is not worth running. + const alreadyTruncated = 'A'.repeat(200) + '\n\n' + 'B'.repeat(200) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock(alreadyTruncated)], + }), + ], + }), + ] + + // First call should return false (too short to gain anything from re-truncating) + const result = (manager as any)._truncateToolResults(messages, 0) + expect(result).toBe(false) + + // reduceContext should fall through to message trimming + const messages2 = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock(alreadyTruncated)], + }), + ], + }), + new Message({ role: 'assistant', content: [new TextBlock('Response')] }), + new Message({ role: 'user', content: [new TextBlock('Message')] }), + ] + const mockAgent = createMockAgent({ messages: messages2 }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // Should have trimmed messages since truncation was skipped + expect(mockAgent.messages.length).toBeLessThan(3) + }) + + it('replaces image blocks nested in tool results with descriptive placeholders', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const bytes = new Uint8Array(1234) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new ImageBlock({ format: 'png', source: { bytes } }), new TextBlock('tail')], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(true) + + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('[image: png, source: bytes, 1234 bytes]'), new TextBlock('tail')], + }) + ) + }) + + it('preserves the error field on truncated tool results', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const originalError = new Error('tool blew up') + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock('x'.repeat(500))], + error: originalError, + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(true) + + const expectedText = `${'x'.repeat(200)}\n\n${'x'.repeat(200)}` + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'error', + content: [new TextBlock(expectedText)], + error: originalError, + }) + ) + }) + + it('image placeholder reflects non-bytes source kinds honestly', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new ImageBlock({ format: 'jpeg', source: { url: 'https://example.com/x.jpg' } }), + new ImageBlock({ format: 'png', source: { location: { type: 's3', uri: 's3://bucket/key' } } }), + ], + }), + ], + }), + ] + + ;(manager as any)._truncateToolResults(messages, 0) + + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('[image: jpeg, source: url]'), new TextBlock('[image: png, source: s3]')], + }) + ) + }) + + it('replaces video bytes blocks with a descriptive placeholder', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new VideoBlock({ format: 'mp4', source: { bytes: new Uint8Array(4096) } })], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(true) + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('[video: mp4, source: bytes, 4096 bytes]')], + }) + ) + }) + + it('replaces video s3 blocks with a descriptive placeholder', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new VideoBlock({ + format: 'mp4', + source: { location: { type: 's3', uri: 's3://bucket/key' } }, + }), + ], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(true) + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('[video: mp4, source: s3]')], + }) + ) + }) + + it('replaces document bytes blocks with a descriptive placeholder', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new DocumentBlock({ + name: 'report', + format: 'pdf', + source: { bytes: new Uint8Array(8192) }, + }), + ], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(true) + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('[document: report, pdf, source: bytes, 8192 bytes]')], + }) + ) + }) + + it('replaces document s3 blocks with a descriptive placeholder', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new DocumentBlock({ + name: 'spec', + format: 'pdf', + source: { location: { type: 's3', uri: 's3://b/k' } }, + }), + ], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(true) + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('[document: spec, pdf, source: s3]')], + }) + ) + }) + + it('partially truncates large text inside a document text source', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const middle = 'M'.repeat(240) + const originalText = 'A'.repeat(200) + middle + 'B'.repeat(200) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new DocumentBlock({ name: 'report', format: 'txt', source: { text: originalText } })], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(true) + + const expectedText = `${'A'.repeat(200)}\n\n${'B'.repeat(200)}` + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new DocumentBlock({ name: 'report', format: 'txt', source: { text: expectedText } })], + }) + ) + }) + + it('leaves small text inside a document text source unchanged', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new DocumentBlock({ name: 'short', format: 'txt', source: { text: 'hello' } })], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(false) + }) + + it('truncates long nested text blocks inside a document content source', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const longText = 'A'.repeat(200) + 'M'.repeat(240) + 'B'.repeat(200) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new DocumentBlock({ + name: 'pages', + format: 'txt', + source: { content: [new TextBlock(longText), new TextBlock('short')] }, + }), + ], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(true) + + const expectedText = `${'A'.repeat(200)}\n\n${'B'.repeat(200)}` + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new DocumentBlock({ + name: 'pages', + format: 'txt', + source: { content: [new TextBlock(expectedText), new TextBlock('short')] }, + }), + ], + }) + ) + }) + + it('replaces large json blocks with a size placeholder', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const big = { payload: 'x'.repeat(1000) } + const size = JSON.stringify(big).length + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new JsonBlock({ json: big })], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(true) + expect(messages[0]!.content[0]).toEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock(`[json: ${size} chars]`)], + }) + ) + }) + + it('leaves small json blocks unchanged', () => { + const manager = new SlidingWindowConversationManager({ shouldTruncateResults: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new JsonBlock({ json: { ok: true } })], + }), + ], + }), + ] + + const changed = (manager as any)._truncateToolResults(messages, 0) + expect(changed).toBe(false) + }) + + it('does not call truncateToolResults unless an error is passed in', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: true }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'tool1', toolUseId: 'id-1', input: {} })], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Tool result content')], + }), + ], + }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + ] + const mockAgent = createMockAgent({ messages }) + + // Spy on _truncateToolResults to verify it's NOT called + const truncateSpy = vi.spyOn(manager as any, '_truncateToolResults') + + // Trigger window size enforcement (no error parameter) + await triggerSlidingWindow(manager, mockAgent) + + // Verify _truncateToolResults was NOT called during window enforcement + expect(truncateSpy).not.toHaveBeenCalled() + + // Should have trimmed; first message must be user + expect(mockAgent.messages.length).toBe(1) + expect(mockAgent.messages[0]!.role).toBe('user') + + truncateSpy.mockRestore() + }) + }) + + describe('reduceContext - message trimming', () => { + it('trims oldest messages when tool results cannot be truncated', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 3, shouldTruncateResults: false }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), + new Message({ role: 'user', content: [new TextBlock('Message 3')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + expect(mockAgent.messages).toHaveLength(3) + expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'Message 2' }) + }) + + it('calculates correct trim index (messages.length - windowSize)', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2 }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // Should remove 2 messages (4 - 2 = 2) + expect(mockAgent.messages).toHaveLength(2) + }) + + it('removes all messages when windowSize is 0 on context overflow', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 0, shouldTruncateResults: false }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + expect(mockAgent.messages).toHaveLength(0) + }) + + it('uses default trim index of 2 when messages <= windowSize', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 5 }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // Should remove 2 messages (default when count <= windowSize) + expect(mockAgent.messages).toHaveLength(1) + }) + + it('removes messages from start of array using splice', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2 }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // Should keep last 2 messages + expect(mockAgent.messages).toHaveLength(2) + expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'Message 2' }) + expect(mockAgent.messages[1]!.content[0]!).toEqual({ type: 'textBlock', text: 'Response 2' }) + }) + }) + + describe('reduceContext - tool pair validation', () => { + it('does not trim at index where oldest message is toolResult', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: false }) + const messages = [ + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'tool1', toolUseId: 'id-1', input: {} })], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }), + new Message({ role: 'assistant', content: [new TextBlock('Response')] }), + new Message({ role: 'user', content: [new TextBlock('Message')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // Skips index 1 (toolResult) and index 2 (assistant), trims at index 3 (user) + expect(mockAgent.messages).toHaveLength(1) + expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'Message' }) + }) + + it('does not trim at index where oldest message is toolUse without following toolResult', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: false }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'tool1', toolUseId: 'id-1', input: {} })], + }), + new Message({ role: 'assistant', content: [new TextBlock('Response')] }), // Not a toolResult + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // Skips index 1 (toolUse without following toolResult), skips index 2 (assistant), trims at index 3 (user) + expect(mockAgent.messages).toHaveLength(1) + expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'Message 2' }) + }) + + it('allows trim when oldest message is toolUse with following toolResult', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2, shouldTruncateResults: false }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'tool1', toolUseId: 'id-1', input: {} })], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }), + new Message({ role: 'assistant', content: [new TextBlock('Response')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // trimIndex starts at 3 (5 - 2 = 3), which is assistant 'Response' — skipped (not user). + // trimIndex 4 is user 'Message 2' — valid. + expect(mockAgent.messages).toHaveLength(1) + expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'Message 2' }) + }) + + it('allows trim at toolUse when toolResult immediately follows', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 4, shouldTruncateResults: false }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'tool1', toolUseId: 'id-1', input: {} })], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }), + new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // trimIndex starts at 2 (6 - 4 = 2), which is user 'Message 2' — valid trim point + expect(mockAgent.messages).toHaveLength(4) + expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'Message 2' }) + }) + + it('allows trim when oldest message is text or other non-tool content', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 2 }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), + ] + const mockAgent = createMockAgent({ messages }) + + await triggerContextOverflow(manager, mockAgent, new ContextWindowOverflowError('Context overflow')) + + // trimIndex starts at 2 (4 - 2 = 2), which is user 'Message 2' — valid + expect(mockAgent.messages).toHaveLength(2) + expect(mockAgent.messages[0]!.content[0]).toEqual({ type: 'textBlock', text: 'Message 2' }) + }) + + it('skips assistant message to ensure trimmed conversation starts with user', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 8 }) + const messages = Array.from( + { length: 9 }, + (_, i) => new Message({ role: i % 2 === 0 ? 'user' : 'assistant', content: [new TextBlock(`message ${i}`)] }) + ) + const mockAgent = createMockAgent({ messages }) + + await triggerSlidingWindow(manager, mockAgent) + + // Naive trim would leave assistant at index 1 as first message. + // Fix skips it so conversation starts with user at index 2. + expect(mockAgent.messages[0]!.role).toBe('user') + expect(mockAgent.messages[0]!.content[0]!).toEqual({ type: 'textBlock', text: 'message 2' }) + }) + + it('returns false when no valid trim point exists', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 1, shouldTruncateResults: false }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }), + ] + + const result = manager.reduce({ + agent: createMockAgent({ messages }), + model: {} as Model, + error: new ContextWindowOverflowError('Context overflow'), + }) + + expect(result).toBe(false) + }) + + it('propagates the original ContextWindowOverflowError when reduce cannot reduce further', async () => { + const manager = new SlidingWindowConversationManager({ windowSize: 1, shouldTruncateResults: false }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }), + ] + const mockAgent = createMockAgent({ messages }) + const originalError = new ContextWindowOverflowError('Context overflow') + + // The base class hook does not set event.retry when reduce returns false, + // so the original error propagates out of the hook chain + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + attemptCount: 1, + error: originalError, + invocationState: {}, + }) + const pluginAgent = createMockAgent() + manager.initAgent(pluginAgent) + await invokeTrackedHook(pluginAgent, event) + + expect(event.retry).toBeUndefined() + }) + }) + + describe('helper methods', () => { + describe('findOldestMessageWithToolResults', () => { + it('returns correct index when tool results exist', () => { + const manager = new SlidingWindowConversationManager() + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Result 1')], + }), + ], + }), + new Message({ role: 'assistant', content: [new TextBlock('Response')] }), + ] + + const index = (manager as any)._findOldestMessageWithToolResults(messages) + expect(index).toBe(1) + }) + + it('returns undefined when no tool results exist', () => { + const manager = new SlidingWindowConversationManager() + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + + const index = (manager as any)._findOldestMessageWithToolResults(messages) + expect(index).toBeUndefined() + }) + + it('iterates forward from start', () => { + const manager = new SlidingWindowConversationManager() + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Result 1')], + }), + ], + }), + new Message({ role: 'assistant', content: [new TextBlock('Response')] }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-2', + status: 'success', + content: [new TextBlock('Result 2')], + }), + ], + }), + ] + + const index = (manager as any)._findOldestMessageWithToolResults(messages) + // Should find the first one (index 0), not the last (index 2) + expect(index).toBe(0) + }) + }) + + describe('truncateToolResults', () => { + it('returns true when changes are made', () => { + const manager = new SlidingWindowConversationManager() + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('x'.repeat(500))], + }), + ], + }), + ] + + const result = (manager as any)._truncateToolResults(messages, 0) + expect(result).toBe(true) + }) + + it('returns false when already truncated', () => { + const manager = new SlidingWindowConversationManager() + const alreadyTruncated = 'A'.repeat(200) + '\n\n' + 'B'.repeat(200) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock(alreadyTruncated)], + }), + ], + }), + ] + + const result = (manager as any)._truncateToolResults(messages, 0) + expect(result).toBe(false) + }) + + it('returns false when no tool results found', () => { + const manager = new SlidingWindowConversationManager() + const messages = [new Message({ role: 'user', content: [new TextBlock('Message')] })] + + const result = (manager as any)._truncateToolResults(messages, 0) + expect(result).toBe(false) + }) + }) + }) + + describe('reduceOnThreshold', () => { + it('trims oldest messages when compressionThreshold is exceeded', async () => { + const manager = new SlidingWindowConversationManager({ + windowSize: 4, + proactiveCompression: { compressionThreshold: 0.7 }, + }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + new Message({ role: 'user', content: [new TextBlock('Message 2')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 2')] }), + new Message({ role: 'user', content: [new TextBlock('Message 3')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 3')] }), + ] + const mockModel = { getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig } as any + const mockAgent = createMockAgent({ messages }) + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 800, + }) + await invokeTrackedHook(mockAgent, event) + + expect(mockAgent.messages.length).toBe(4) + }) + + it('does not trim when below compressionThreshold', async () => { + const manager = new SlidingWindowConversationManager({ + windowSize: 4, + proactiveCompression: { compressionThreshold: 0.7 }, + }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Message 1')] }), + new Message({ role: 'assistant', content: [new TextBlock('Response 1')] }), + ] + const mockModel = { getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig } as any + const mockAgent = createMockAgent({ messages }) + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 500, + }) + await invokeTrackedHook(mockAgent, event) + + expect(mockAgent.messages).toHaveLength(2) + }) + }) +}) diff --git a/strands-ts/src/conversation-manager/__tests__/summarizing-conversation-manager.test.ts b/strands-ts/src/conversation-manager/__tests__/summarizing-conversation-manager.test.ts new file mode 100644 index 0000000000..cf2ea07942 --- /dev/null +++ b/strands-ts/src/conversation-manager/__tests__/summarizing-conversation-manager.test.ts @@ -0,0 +1,409 @@ +import { describe, it, expect, vi } from 'vitest' +import { SummarizingConversationManager } from '../summarizing-conversation-manager.js' +import { ContextWindowOverflowError, Message, TextBlock, ToolUseBlock, ToolResultBlock } from '../../index.js' +import { AfterModelCallEvent, BeforeModelCallEvent } from '../../hooks/events.js' +import { createMockAgent, invokeTrackedHook } from '../../__fixtures__/agent-helpers.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import type { Model, BaseModelConfig } from '../../models/model.js' + +function textMsg(role: 'user' | 'assistant', text: string): Message { + return new Message({ role, content: [new TextBlock(text)] }) +} + +function makeMessages(count: number): Message[] { + return Array.from({ length: count }, (_, i) => textMsg(i % 2 === 0 ? 'user' : 'assistant', `Message ${i + 1}`)) +} + +describe('SummarizingConversationManager', () => { + describe('constructor', () => { + it('clamps summaryRatio to [0.1, 0.8]', () => { + expect((new SummarizingConversationManager({ summaryRatio: 0 }) as any)._summaryRatio).toBe(0.1) + expect((new SummarizingConversationManager({ summaryRatio: 1.0 }) as any)._summaryRatio).toBe(0.8) + }) + }) + + describe('reduce', () => { + it('summarizes oldest messages and replaces them with a user-role summary', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Summary of conversation' }) + + const manager = new SummarizingConversationManager({ + summaryRatio: 0.5, + preserveRecentMessages: 2, + }) + const messages = makeMessages(20) + const lastTwo = messages.slice(-2) + const mockAgent = createMockAgent({ messages }) + + const result = await manager.reduce({ + agent: mockAgent, + model: model as unknown as Model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(result).toBe(true) + // 20 * 0.5 = 10 summarized → 1 summary + 10 remaining = 11 + expect(mockAgent.messages).toHaveLength(11) + expect(mockAgent.messages[0]!.role).toBe('user') + expect(mockAgent.messages[0]!.content[0]!).toEqual({ + type: 'textBlock', + text: 'Summary of conversation', + }) + // Recent messages preserved + expect(mockAgent.messages.slice(-2)).toEqual(lastTwo) + }) + + it('uses the config model over the reduce model when provided', async () => { + const configModel = new MockMessageModel() + configModel.addTurn({ type: 'textBlock', text: 'Config model summary' }) + const reduceModel = new MockMessageModel() + reduceModel.addTurn({ type: 'textBlock', text: 'Reduce model summary' }) + + const manager = new SummarizingConversationManager({ + model: configModel as unknown as Model, + summaryRatio: 0.5, + preserveRecentMessages: 2, + }) + const messages = makeMessages(20) + const mockAgent = createMockAgent({ messages }) + + await manager.reduce({ + agent: mockAgent, + model: reduceModel as unknown as Model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(mockAgent.messages[0]!.content[0]!).toEqual({ + type: 'textBlock', + text: 'Config model summary', + }) + }) + + it('uses the config model when no reduce model is provided', async () => { + const configModel = new MockMessageModel() + configModel.addTurn({ type: 'textBlock', text: 'Config model summary' }) + + const manager = new SummarizingConversationManager({ + model: configModel as unknown as Model, + summaryRatio: 0.5, + preserveRecentMessages: 2, + }) + const messages = makeMessages(20) + const mockAgent = createMockAgent({ messages }) + + const result = await manager.reduce({ + agent: mockAgent, + model: {} as Model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(result).toBe(true) + expect(mockAgent.messages[0]!.content[0]!).toEqual({ + type: 'textBlock', + text: 'Config model summary', + }) + }) + + it('returns false when there are not enough messages to summarize', async () => { + const model = new MockMessageModel() + const manager = new SummarizingConversationManager({ + preserveRecentMessages: 10, + }) + const messages = makeMessages(8) + const mockAgent = createMockAgent({ messages }) + + const result = await manager.reduce({ + agent: mockAgent, + model: model as unknown as Model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(result).toBe(false) + expect(mockAgent.messages).toHaveLength(8) + }) + + it('rethrows model errors with the overflow error as cause', async () => { + const model = new MockMessageModel() + model.addTurn(new Error('model failed')) + + const manager = new SummarizingConversationManager({ + summaryRatio: 0.5, + preserveRecentMessages: 2, + }) + const overflowError = new ContextWindowOverflowError('overflow') + const mockAgent = createMockAgent({ messages: makeMessages(20) }) + + const thrown = await manager + .reduce({ agent: mockAgent, model: model as unknown as Model, error: overflowError }) + .catch((e: unknown) => e) + expect(thrown).toBeInstanceOf(Error) + expect((thrown as Error).message).toBe('model failed') + expect((thrown as Error).cause).toBe(overflowError) + }) + + it('wraps non-Error throw values with the overflow error as cause', async () => { + const model = new MockMessageModel() + const err = 'string error' + vi.spyOn(model, 'streamAggregated').mockImplementation(async function* () { + yield undefined as any + throw err + } as any) + + const manager = new SummarizingConversationManager({ + summaryRatio: 0.5, + preserveRecentMessages: 2, + }) + const overflowError = new ContextWindowOverflowError('overflow') + const mockAgent = createMockAgent({ messages: makeMessages(20) }) + + const thrown = await manager + .reduce({ agent: mockAgent, model: model as unknown as Model, error: overflowError }) + .catch((e: unknown) => e) + expect(thrown).toBeInstanceOf(Error) + expect((thrown as Error).message).toBe('string error') + expect((thrown as Error).cause).toBe(overflowError) + }) + + it('passes the correct message slice and system prompt to the model', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Summary' }) + const streamSpy = vi.spyOn(model, 'stream') + + const customPrompt = 'Custom summarization prompt' + const manager = new SummarizingConversationManager({ + summaryRatio: 0.5, + preserveRecentMessages: 2, + summarizationSystemPrompt: customPrompt, + }) + const messages = makeMessages(10) + const expectedSlice = messages.slice(0, 5) + const mockAgent = createMockAgent({ messages }) + + await manager.reduce({ + agent: mockAgent, + model: model as unknown as Model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(streamSpy).toHaveBeenCalledOnce() + const [calledMessages, calledOptions] = streamSpy.mock.calls[0]! + // First 5 messages (10 * 0.5) plus the "Please summarize" request + expect(calledMessages).toHaveLength(6) + expect(calledMessages!.slice(0, 5)).toEqual(expectedSlice) + expect(calledMessages![5]!.role).toBe('user') + expect(calledMessages![5]!.content[0]!).toEqual( + expect.objectContaining({ text: 'Please summarize this conversation.' }) + ) + expect(calledOptions).toEqual(expect.objectContaining({ systemPrompt: customPrompt })) + }) + + it('preserveRecentMessages dominates when larger than ratio allows', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Summary' }) + + const manager = new SummarizingConversationManager({ + summaryRatio: 0.8, + preserveRecentMessages: 18, + }) + const messages = makeMessages(20) + const mockAgent = createMockAgent({ messages }) + + const result = await manager.reduce({ + agent: mockAgent, + model: model as unknown as Model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(result).toBe(true) + // 20 * 0.8 = 16, but min(16, 20-18) = 2, so only 2 summarized + // 1 summary + 18 remaining = 19 + expect(mockAgent.messages).toHaveLength(19) + }) + }) + + describe('tool pair adjustment', () => { + it('advances split point past orphaned toolResult and toolUse boundaries', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Summary' }) + + const manager = new SummarizingConversationManager({ + summaryRatio: 0.3, + preserveRecentMessages: 2, + }) + + // Natural split at ~index 3 lands on a toolResult + const messages = [ + textMsg('user', 'Message 1'), + textMsg('assistant', 'Message 2'), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'tool1', toolUseId: 'id-1', input: {} })], + }), + new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId: 'id-1', status: 'success', content: [new TextBlock('Result')] })], + }), + textMsg('assistant', 'Response after tool'), + ...makeMessages(8), + ] + const mockAgent = createMockAgent({ messages }) + + const result = await manager.reduce({ + agent: mockAgent, + model: model as unknown as Model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(result).toBe(true) + // After summary insertion, no remaining message should start with an orphaned toolResult + expect(mockAgent.messages[1]!.content.some((b) => b.type === 'toolResultBlock')).toBe(false) + }) + + it('throws when no valid split point exists', async () => { + const model = new MockMessageModel() + const manager = new SummarizingConversationManager({ + summaryRatio: 0.5, + preserveRecentMessages: 0, + }) + + // All messages are toolResults + const messages = Array.from( + { length: 4 }, + (_, i) => + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ toolUseId: `id-${i}`, status: 'success', content: [new TextBlock(`R${i}`)] }), + ], + }) + ) + const mockAgent = createMockAgent({ messages }) + + await expect( + manager.reduce({ + agent: mockAgent, + model: model as unknown as Model, + error: new ContextWindowOverflowError('overflow'), + }) + ).rejects.toThrow('Unable to find valid split point for summarization') + }) + }) + + describe('base class hook integration', () => { + // Two agents: pluginAgent receives the hook registration via initAgent(), + // while agent holds the messages and is carried on the event object. + it('async reduce sets retry=true through the base class await', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Summary' }) + + const manager = new SummarizingConversationManager({ + summaryRatio: 0.5, + preserveRecentMessages: 2, + }) + const messages = makeMessages(20) + const agent = createMockAgent({ messages }) + + const pluginAgent = createMockAgent() + manager.initAgent(pluginAgent) + const event = new AfterModelCallEvent({ + agent, + model: model as unknown as Model, + attemptCount: 1, + error: new ContextWindowOverflowError('overflow'), + invocationState: {}, + }) + await invokeTrackedHook(pluginAgent, event) + + expect(event.retry).toBe(true) + expect(agent.messages).toHaveLength(11) + }) + }) + + describe('reduceOnThreshold', () => { + it('summarizes oldest messages when compressionThreshold is exceeded', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Summary of conversation' }) + + const manager = new SummarizingConversationManager({ + model: model as unknown as Model, + summaryRatio: 0.5, + preserveRecentMessages: 2, + proactiveCompression: { compressionThreshold: 0.7 }, + }) + const messages = makeMessages(20) + const mockAgent = createMockAgent({ messages }) + const mockModel = { getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig } as any + + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 800, + }) + await invokeTrackedHook(mockAgent, event) + + // 20 * 0.5 = 10 summarized → 1 summary + 10 remaining = 11 + expect(mockAgent.messages).toHaveLength(11) + expect(mockAgent.messages[0]!.role).toBe('user') + expect(mockAgent.messages[0]!.content[0]!).toEqual({ + type: 'textBlock', + text: 'Summary of conversation', + }) + }) + + it('does not summarize when below compressionThreshold', async () => { + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Summary' }) + + const manager = new SummarizingConversationManager({ + model: model as unknown as Model, + proactiveCompression: { compressionThreshold: 0.7 }, + }) + const messages = makeMessages(20) + const mockAgent = createMockAgent({ messages }) + const mockModel = { getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig } as any + + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 500, + }) + await invokeTrackedHook(mockAgent, event) + + expect(mockAgent.messages).toHaveLength(20) + }) + + it('returns false and does not throw when summarization fails', async () => { + const model = new MockMessageModel() + model.addTurn(new Error('model failed')) + + const manager = new SummarizingConversationManager({ + model: model as unknown as Model, + summaryRatio: 0.5, + preserveRecentMessages: 2, + proactiveCompression: { compressionThreshold: 0.7 }, + }) + const messages = makeMessages(20) + const mockAgent = createMockAgent({ messages }) + const mockModel = { getConfig: () => ({ contextWindowLimit: 1000 }) as BaseModelConfig } as any + + manager.initAgent(mockAgent) + + const event = new BeforeModelCallEvent({ + agent: mockAgent, + model: mockModel, + invocationState: {}, + projectedInputTokens: 800, + }) + + // Should not throw — reduceOnThreshold is best-effort + await invokeTrackedHook(mockAgent, event) + expect(mockAgent.messages).toHaveLength(20) + }) + }) +}) diff --git a/strands-ts/src/conversation-manager/conversation-manager.ts b/strands-ts/src/conversation-manager/conversation-manager.ts new file mode 100644 index 0000000000..078e654c7e --- /dev/null +++ b/strands-ts/src/conversation-manager/conversation-manager.ts @@ -0,0 +1,213 @@ +/** + * Abstract base class for conversation history management. + * + * This module defines the ConversationManager abstraction, which provides a + * domain-specific interface for managing an agent's conversation context. + */ + +import type { Plugin } from '../plugins/plugin.js' +import type { LocalAgent } from '../types/agent.js' +import { AfterModelCallEvent, BeforeModelCallEvent } from '../hooks/events.js' +import { ContextWindowOverflowError } from '../errors.js' +import type { Model } from '../models/model.js' +import { logger } from '../logging/logger.js' +import { warnOnce } from '../logging/warn-once.js' + +/** Default compression threshold ratio. */ +const DEFAULT_COMPRESSION_THRESHOLD = 0.7 + +/** Default context window limit fallback when the model doesn't report one. */ +const DEFAULT_CONTEXT_WINDOW_LIMIT = 200_000 + +/** + * Options passed to {@link ConversationManager.reduce}. + * + * When `error` is set, this is a reactive overflow recovery call — the implementation + * MUST remove enough history for the next model call to succeed. + * + * When `error` is undefined, this is a proactive compression call — best-effort reduction + * to avoid hitting the context window limit. + */ +export type ConversationManagerReduceOptions = { + /** + * The agent instance. Mutate `agent.messages` in place to reduce history. + */ + agent: LocalAgent + + /** + * The model instance. Used by conversation managers that perform model-based + * reduction (e.g. summarization). + */ + model: Model + + /** + * The {@link ContextWindowOverflowError} that triggered this call, or `undefined` + * for proactive compression calls. + * + * When set, `reduce` MUST remove enough history for the next model call to succeed, + * or this error will propagate out of the agent loop uncaught. + * + * When undefined, `reduce` is best-effort — errors are swallowed and the model call + * proceeds regardless. + */ + error?: ContextWindowOverflowError +} + +/** + * Configuration for proactive compression when passed as an object. + */ +export type ProactiveCompressionConfig = { + /** + * Ratio of context window usage that triggers proactive compression. + * Value between 0 (exclusive) and 1 (inclusive). + * Defaults to 0.7 (compress when 70% of the context window is used). + */ + compressionThreshold: number +} + +/** + * Configuration options for the ConversationManager base class. + */ +export type ConversationManagerOptions = { + /** + * Enable proactive context compression before the model call. + * + * - `true`: compress when 70% of the context window is used (default threshold). + * - `{ compressionThreshold: number }`: compress at the specified ratio (0, 1]. + * - `false` or omitted: disabled, only reactive overflow recovery is used. + */ + proactiveCompression?: boolean | ProactiveCompressionConfig +} + +/** + * Abstract base class for conversation history management strategies. + * + * The primary responsibility of a ConversationManager is overflow recovery: when the + * model returns a {@link ContextWindowOverflowError}, {@link ConversationManager.reduce} + * is called with `error` set and MUST reduce the history enough for the next model call + * to succeed. If `reduce` returns `false` (no reduction performed), the error propagates + * out of the agent loop uncaught. This makes `reduce` a critical operation — + * implementations must be able to make meaningful progress when called with `error` set. + * + * Subclasses can enable proactive compression by passing `proactiveCompression` in the + * options object to the base constructor. When enabled, the base class registers a + * `BeforeModelCallEvent` hook that checks projected input tokens against the model's + * context window limit and calls `reduce` (without `error`) when the threshold is exceeded. + * + * @example + * ```typescript + * class Last10MessagesManager extends ConversationManager { + * readonly name = 'my:last-10-messages' + * + * reduce({ agent }: ReduceOptions): boolean { + * if (agent.messages.length <= 10) return false + * agent.messages.splice(0, agent.messages.length - 10) + * return true + * } + * } + * ``` + */ +export abstract class ConversationManager implements Plugin { + /** + * A stable string identifier for this conversation manager. + */ + abstract readonly name: string + + protected readonly _compressionThreshold: number | undefined + + /** + * @param options - Configuration options for the conversation manager. + */ + constructor(options?: ConversationManagerOptions) { + const proactiveCompression = options?.proactiveCompression + const threshold = + proactiveCompression === true + ? DEFAULT_COMPRESSION_THRESHOLD + : proactiveCompression + ? proactiveCompression.compressionThreshold + : undefined + + if (threshold !== undefined && (threshold <= 0 || threshold > 1)) { + throw new Error(`compressionThreshold must be between 0 (exclusive) and 1 (inclusive), got ${threshold}`) + } + this._compressionThreshold = threshold + } + + /** + * Reduce the conversation history. + * + * Called in two scenarios: + * 1. **Reactive** (error set): A {@link ContextWindowOverflowError} occurred. The implementation + * MUST remove enough history for the next model call to succeed. Returning `false` means no + * reduction was possible, and the error will propagate out of the agent loop. + * 2. **Proactive** (error undefined): The compression threshold was exceeded. This is best-effort — + * returning `false` or throwing is acceptable; the model call proceeds regardless. + * + * Implementations should mutate `agent.messages` in place and return `true` if any reduction + * was performed, `false` otherwise. + * + * @param options - The reduction options + * @returns `true` if the history was reduced, `false` otherwise. + * May return a `Promise` for implementations that need async I/O (e.g. model calls). + */ + abstract reduce(options: ConversationManagerReduceOptions): boolean | Promise + + /** + * Initialize the conversation manager with the agent instance. + * + * Registers two hooks: + * 1. `AfterModelCallEvent`: Overflow recovery — when a {@link ContextWindowOverflowError} occurs, + * calls {@link ConversationManager.reduce} with `error` set and retries if reduction succeeded. + * 2. `BeforeModelCallEvent`: Proactive compression — when projected input tokens exceed the + * configured compression threshold, calls {@link ConversationManager.reduce} without `error`. + * The hook is always registered but only acts when proactive compression is enabled. + * + * Subclasses that override `initAgent` MUST call `super.initAgent(agent)` to + * preserve overflow recovery and proactive compression behavior. + * + * @param agent - The agent to register hooks with + */ + initAgent(agent: LocalAgent): void { + // Reactive overflow recovery + agent.addHook(AfterModelCallEvent, async (event) => { + if (event.error instanceof ContextWindowOverflowError) { + if (await this.reduce({ agent: event.agent, model: event.model, error: event.error })) { + event.retry = true + } + } + }) + + // Proactive compression — always subscribe, check threshold in the handler + agent.addHook(BeforeModelCallEvent, async (event) => { + if (this._compressionThreshold === undefined) { + return + } + + let contextWindowLimit = event.model.getConfig().contextWindowLimit + if (contextWindowLimit === undefined) { + contextWindowLimit = DEFAULT_CONTEXT_WINDOW_LIMIT + warnOnce( + logger, + `conversation_manager=<${this.name}> | contextWindowLimit is not set on the model, using default of ${DEFAULT_CONTEXT_WINDOW_LIMIT} | set contextWindowLimit in your model config for accurate proactive compression` + ) + } + + if (event.projectedInputTokens === undefined) { + return + } + + const ratio = event.projectedInputTokens / contextWindowLimit + if (ratio >= this._compressionThreshold) { + logger.debug( + `projected_tokens=<${event.projectedInputTokens}>, limit=<${contextWindowLimit}>, ratio=<${ratio.toFixed(2)}>, compression_threshold=<${this._compressionThreshold}> | compression threshold exceeded, reducing context` + ) + // Proactive compression is best-effort: swallow errors so the model call can still proceed. + try { + await this.reduce({ agent: event.agent, model: event.model }) + } catch (e) { + logger.warn(`conversation_manager=<${this.name}> | proactive compression failed, continuing | error=<${e}>`) + } + } + }) + } +} diff --git a/strands-ts/src/conversation-manager/index.ts b/strands-ts/src/conversation-manager/index.ts new file mode 100644 index 0000000000..9ebbffd703 --- /dev/null +++ b/strands-ts/src/conversation-manager/index.ts @@ -0,0 +1,21 @@ +/** + * Conversation Manager exports. + * + * This module exports conversation manager implementations. + */ + +export { + ConversationManager, + type ProactiveCompressionConfig, + type ConversationManagerReduceOptions as ReduceOptions, + type ConversationManagerOptions, +} from './conversation-manager.js' +export { NullConversationManager } from './null-conversation-manager.js' +export { + SlidingWindowConversationManager, + type SlidingWindowConversationManagerConfig, +} from './sliding-window-conversation-manager.js' +export { + SummarizingConversationManager, + type SummarizingConversationManagerConfig, +} from './summarizing-conversation-manager.js' diff --git a/strands-ts/src/conversation-manager/null-conversation-manager.ts b/strands-ts/src/conversation-manager/null-conversation-manager.ts new file mode 100644 index 0000000000..f773df4059 --- /dev/null +++ b/strands-ts/src/conversation-manager/null-conversation-manager.ts @@ -0,0 +1,31 @@ +/** + * Null implementation of conversation management. + * + * This module provides a no-op conversation manager that does not modify + * the conversation history. Useful for testing and scenarios where conversation + * management is handled externally. + */ + +import { ConversationManager, type ConversationManagerReduceOptions } from './conversation-manager.js' + +/** + * A no-op conversation manager that does not modify the conversation history. + * + * Does not register any proactive hooks. Overflow errors will not be retried + * since `reduce` always returns `false`. + */ +export class NullConversationManager extends ConversationManager { + /** + * Unique identifier for this conversation manager. + */ + readonly name = 'strands:null-conversation-manager' + + /** + * No-op reduction — never modifies the conversation history. + * + * @returns `false` always + */ + reduce(_args: ConversationManagerReduceOptions): boolean { + return false + } +} diff --git a/strands-ts/src/conversation-manager/sliding-window-conversation-manager.ts b/strands-ts/src/conversation-manager/sliding-window-conversation-manager.ts new file mode 100644 index 0000000000..a32f4282b8 --- /dev/null +++ b/strands-ts/src/conversation-manager/sliding-window-conversation-manager.ts @@ -0,0 +1,463 @@ +/** + * Sliding window conversation history management. + * + * This module provides a sliding window strategy for managing conversation history + * that preserves tool usage pairs and avoids invalid window states. + */ + +import { Message, TextBlock, ToolResultBlock, type ToolResultContent } from '../types/messages.js' +import { DocumentBlock, ImageBlock, VideoBlock } from '../types/media.js' +import type { LocalAgent } from '../types/agent.js' +import { AfterInvocationEvent } from '../hooks/events.js' +import { + ConversationManager, + type ProactiveCompressionConfig, + type ConversationManagerReduceOptions, +} from './conversation-manager.js' +import { logger } from '../logging/logger.js' + +const PRESERVE_CHARS = 200 +// Max plausible marker length, including newlines. Used as the minimum reduction +// a re-truncation would need to produce in order to be worth running. +const MIN_TRUNCATION_GAIN = 50 +// Text payloads at or below this length aren't worth truncating: the savings +// would be smaller than the marker itself, and already-truncated output (which +// lands just above `2 * PRESERVE_CHARS`) falls under this threshold so a +// second pass is a natural no-op. +const TRUNCATION_THRESHOLD = 2 * PRESERVE_CHARS + MIN_TRUNCATION_GAIN + +/** + * Build a short textual stand-in for an image block, used when truncating tool + * results. The placeholder identifies the image format and its source kind + * (bytes/url/s3) so the model can reason about what was dropped. For inline + * bytes the size is included; URL and S3 sources only report the kind since + * their byte count isn't known locally. + */ +function imagePlaceholder(image: ImageBlock): string { + const source = image.source + if (source.type === 'imageSourceBytes') { + return `[image: ${image.format}, source: bytes, ${source.bytes.byteLength} bytes]` + } + if (source.type === 'imageSourceUrl') { + return `[image: ${image.format}, source: url]` + } + return `[image: ${image.format}, source: s3]` +} + +/** + * Build a short textual stand-in for a video block. Binary payloads can't be + * partially inspected, so videos are replaced wholesale. The placeholder + * reports format and source kind; byte count is included for inline bytes. + */ +function videoPlaceholder(video: VideoBlock): string { + const source = video.source + if (source.type === 'videoSourceBytes') { + return `[video: ${video.format}, source: bytes, ${source.bytes.byteLength} bytes]` + } + return `[video: ${video.format}, source: s3]` +} + +/** + * Build a short textual stand-in for a document block with a binary or remote + * source. Text-based document sources (text / content) are truncated in place + * instead of replaced, so this is only called for bytes / s3. + */ +function documentPlaceholder(doc: DocumentBlock): string { + const source = doc.source + if (source.type === 'documentSourceBytes') { + return `[document: ${doc.name}, ${doc.format}, source: bytes, ${source.bytes.byteLength} bytes]` + } + return `[document: ${doc.name}, ${doc.format}, source: s3]` +} + +/** + * Build a short textual stand-in for a JSON block. The serialized length is + * reported so the model knows how much was dropped; truncating JSON + * mid-structure would produce invalid output, so the whole block is replaced. + */ +function jsonPlaceholder(serializedLength: number): string { + return `[json: ${serializedLength} chars]` +} + +/** + * Configuration for the sliding window conversation manager. + */ +export type SlidingWindowConversationManagerConfig = { + /** + * Maximum number of messages to keep in the conversation history. + * Defaults to 40 messages. + */ + windowSize?: number + + /** + * Whether to truncate tool results when a message is too large for the model's context window. + * Defaults to true. + */ + shouldTruncateResults?: boolean + + /** + * Enable proactive context compression before the model call. + * + * - `true`: compress when 70% of the context window is used (default threshold). + * - `{ compressionThreshold: number }`: compress at the specified ratio (0, 1]. + * - `false` or omitted: disabled, only reactive overflow recovery is used. + */ + proactiveCompression?: boolean | ProactiveCompressionConfig +} + +/** + * Implements a sliding window strategy for managing conversation history. + * + * This class handles the logic of maintaining a conversation window that preserves + * tool usage pairs and avoids invalid window states. When the message count exceeds + * the window size, it will either truncate large tool results or remove the oldest + * messages while ensuring tool use/result pairs remain valid. + * + * Registers hooks for: + * - AfterInvocationEvent: Applies sliding window management after each invocation + * - AfterModelCallEvent: Reduces context on overflow errors and requests retry (via super) + * - BeforeModelCallEvent: Proactive compression when threshold is exceeded (via super) + */ +export class SlidingWindowConversationManager extends ConversationManager { + private readonly _windowSize: number + private readonly _shouldTruncateResults: boolean + + /** + * Unique identifier for this conversation manager. + */ + readonly name = 'strands:sliding-window-conversation-manager' + + /** + * Initialize the sliding window conversation manager. + * + * @param config - Configuration options for the sliding window manager. + */ + constructor(config?: SlidingWindowConversationManagerConfig) { + super(config) + this._windowSize = config?.windowSize ?? 40 + this._shouldTruncateResults = config?.shouldTruncateResults ?? true + } + + /** + * Initialize the plugin by registering hooks with the agent. + * + * Registers: + * - AfterInvocationEvent callback to apply sliding window management + * - AfterModelCallEvent callback to handle context overflow and request retry (via super) + * - BeforeModelCallEvent callback for proactive compression (via super) + * + * @param agent - The agent to register hooks with + */ + public override initAgent(agent: LocalAgent): void { + super.initAgent(agent) + + agent.addHook(AfterInvocationEvent, (event) => { + this._applyManagement(event.agent.messages) + }) + } + + /** + * Reduce the conversation history. + * + * When `error` is set (reactive overflow recovery), attempts to truncate large tool results + * first before falling back to message trimming. + * + * When `error` is undefined (proactive compression), only trims messages without attempting + * tool result truncation. + * + * @param options - The reduction options + * @returns `true` if the history was reduced, `false` otherwise + */ + reduce({ agent, error }: ConversationManagerReduceOptions): boolean { + return this._reduceContext(agent.messages, error) + } + + /** + * Apply the sliding window to the messages array to maintain a manageable history size. + * + * Called after every agent invocation. No-op if within the window size. + * + * @param messages - The message array to manage. Modified in-place. + */ + private _applyManagement(messages: Message[]): void { + if (messages.length <= this._windowSize) { + return + } + + this._reduceContext(messages, undefined) + } + + /** + * Trim the oldest messages to reduce the conversation context size. + * + * The method handles special cases where trimming the messages leads to: + * - toolResult with no corresponding toolUse + * - toolUse with no corresponding toolResult + * + * The strategy is: + * 1. First, attempt to truncate large tool results if shouldTruncateResults is true + * 2. If truncation is not possible or doesn't help, trim oldest messages + * 3. When trimming, skip invalid trim points (toolResult at start, or toolUse without following toolResult) + * + * @param messages - The message array to reduce. Modified in-place. + * @param _error - The error that triggered the context reduction, if any. + * @returns `true` if any reduction occurred, `false` otherwise. + */ + private _reduceContext(messages: Message[], _error?: Error): boolean { + // Only truncate tool results when handling a context overflow error, not for window size enforcement + const oldestMessageIdxWithToolResults = this._findOldestMessageWithToolResults(messages) + if (_error && oldestMessageIdxWithToolResults !== undefined && this._shouldTruncateResults) { + const resultsTruncated = this._truncateToolResults(messages, oldestMessageIdxWithToolResults) + if (resultsTruncated) { + return true + } + } + + // Try to trim messages when tool result cannot be truncated anymore + // If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size + let trimIndex = messages.length <= this._windowSize ? 2 : messages.length - this._windowSize + + // Find the next valid trim point that: + // 1. Starts with a user message (required by some models) + // 2. Does not start with an orphaned toolResult + // 3. Does not start with a toolUse unless its toolResult immediately follows + while (trimIndex < messages.length) { + const oldestMessage = messages[trimIndex] + if (!oldestMessage) { + break + } + + // Must start with a user message + if (oldestMessage.role !== 'user') { + trimIndex++ + continue + } + + // Cannot start with an orphaned toolResult + const hasToolResult = oldestMessage.content.some((block) => block.type === 'toolResultBlock') + if (hasToolResult) { + trimIndex++ + continue + } + + // toolUse is only valid if the next message is its toolResult + const hasToolUse = oldestMessage.content.some((block) => block.type === 'toolUseBlock') + if (hasToolUse) { + const nextMessage = messages[trimIndex + 1] + const nextHasToolResult = nextMessage && nextMessage.content.some((block) => block.type === 'toolResultBlock') + if (!nextHasToolResult) { + trimIndex++ + continue + } + } + + // Valid trim point found + break + } + + // If no valid trim point was found, return false and let the caller handle it. + // When windowSize is 0, trimIndex === messages.length is expected (remove all), so allow it through. + if (trimIndex > messages.length || (trimIndex === messages.length && this._windowSize > 0)) { + logger.warn( + `window_size=<${this._windowSize}>, messages=<${messages.length}> | unable to trim conversation context, no valid trim point found` + ) + return false + } + + // trimIndex is guaranteed to be < messages.length here, so splice always removes at least one message + messages.splice(0, trimIndex) + return true + } + + /** + * Apply head/tail truncation to a string if it exceeds the size threshold. + * + * Returns the truncated form (first {@link PRESERVE_CHARS} + marker + last + * {@link PRESERVE_CHARS}) when the input exceeds {@link TRUNCATION_THRESHOLD}, + * otherwise `undefined`. + */ + private _truncateLongText(text: string): string | undefined { + if (text.length <= TRUNCATION_THRESHOLD) { + return undefined + } + const prefix = text.slice(0, PRESERVE_CHARS) + const suffix = text.slice(-PRESERVE_CHARS) + const removed = text.length - 2 * PRESERVE_CHARS + return `${prefix}\n\n${suffix}` + } + + /** + * Truncate tool result content in a message to reduce context size. + * + * Rule: preserve head/tail when the payload is plain-text-shaped; replace + * wholesale when it's binary or remote. Specifically: + * - Text blocks: partial head/tail truncation if over threshold. + * - Image, Video blocks: wholesale replacement with a textual placeholder. + * - Document blocks with bytes/s3 source: wholesale replacement. + * - Document blocks with text source: partial truncation of the inner text. + * - Document blocks with content source (TextBlock[]): partial truncation of + * each nested block. + * - JSON blocks: wholesale replacement if serialized length is over threshold; + * mid-structure truncation would produce invalid JSON. + * + * The tool result `status` and `error` fields are preserved. + * + * @param messages - The conversation message history. + * @param msgIdx - Index of the message containing tool results to truncate. + * @returns True if any changes were made to the message, false otherwise. + */ + private _truncateToolResults(messages: Message[], msgIdx: number): boolean { + if (msgIdx >= messages.length || msgIdx < 0) { + return false + } + + const message = messages[msgIdx] + if (!message) { + return false + } + + let changesMade = false + const newContent = message.content.map((block) => { + if (block.type !== 'toolResultBlock') { + return block + } + + const toolResultBlock = block as ToolResultBlock + const newItems: ToolResultContent[] = [] + let itemChanged = false + + for (const item of toolResultBlock.content) { + if (item.type === 'imageBlock') { + newItems.push(new TextBlock(imagePlaceholder(item))) + itemChanged = true + continue + } + + if (item.type === 'videoBlock') { + newItems.push(new TextBlock(videoPlaceholder(item))) + itemChanged = true + continue + } + + if (item.type === 'documentBlock') { + const source = item.source + if (source.type === 'documentSourceBytes' || source.type === 'documentSourceS3Location') { + newItems.push(new TextBlock(documentPlaceholder(item))) + itemChanged = true + continue + } + if (source.type === 'documentSourceText') { + const truncated = this._truncateLongText(source.text) + if (truncated !== undefined) { + newItems.push( + new DocumentBlock({ + name: item.name, + format: item.format, + source: { text: truncated }, + ...(item.citations !== undefined ? { citations: item.citations } : {}), + ...(item.context !== undefined ? { context: item.context } : {}), + }) + ) + itemChanged = true + continue + } + } + if (source.type === 'documentSourceContentBlock') { + let nestedChanged = false + const newContentBlocks = source.content.map((nested) => { + const truncated = this._truncateLongText(nested.text) + if (truncated !== undefined) { + nestedChanged = true + return new TextBlock(truncated) + } + return nested + }) + if (nestedChanged) { + newItems.push( + new DocumentBlock({ + name: item.name, + format: item.format, + source: { content: newContentBlocks }, + ...(item.citations !== undefined ? { citations: item.citations } : {}), + ...(item.context !== undefined ? { context: item.context } : {}), + }) + ) + itemChanged = true + continue + } + } + newItems.push(item) + continue + } + + if (item.type === 'jsonBlock') { + const serializedLength = JSON.stringify(item.json).length + if (serializedLength > TRUNCATION_THRESHOLD) { + newItems.push(new TextBlock(jsonPlaceholder(serializedLength))) + itemChanged = true + continue + } + newItems.push(item) + continue + } + + if (item.type === 'textBlock') { + const truncated = this._truncateLongText(item.text) + if (truncated !== undefined) { + newItems.push(new TextBlock(truncated)) + itemChanged = true + continue + } + } + + newItems.push(item) + } + + if (!itemChanged) { + return block + } + + changesMade = true + return new ToolResultBlock({ + toolUseId: toolResultBlock.toolUseId, + status: toolResultBlock.status, + content: newItems, + ...(toolResultBlock.error !== undefined ? { error: toolResultBlock.error } : {}), + }) + }) + + if (!changesMade) { + return false + } + + messages[msgIdx] = new Message({ + role: message.role, + content: newContent, + }) + + return true + } + + /** + * Find the index of the oldest message containing tool results. + * + * Truncation targets the least-recent tool result first so the most relevant + * recent context is preserved as long as possible. + * + * @param messages - The conversation message history. + * @returns Index of the oldest message with tool results, or undefined if no such message exists. + */ + private _findOldestMessageWithToolResults(messages: Message[]): number | undefined { + for (let idx = 0; idx < messages.length; idx++) { + const currentMessage = messages[idx]! + + const hasToolResult = currentMessage.content.some((block) => block.type === 'toolResultBlock') + + if (hasToolResult) { + return idx + } + } + + return undefined + } +} diff --git a/strands-ts/src/conversation-manager/summarizing-conversation-manager.ts b/strands-ts/src/conversation-manager/summarizing-conversation-manager.ts new file mode 100644 index 0000000000..710311a9a2 --- /dev/null +++ b/strands-ts/src/conversation-manager/summarizing-conversation-manager.ts @@ -0,0 +1,263 @@ +/** + * Summarization-based conversation history management. + * + * This module provides a conversation manager that summarizes older messages + * when the context window overflows, preserving important information rather + * than simply discarding it. + */ + +import { Message, TextBlock } from '../types/messages.js' +import type { LocalAgent } from '../types/agent.js' +import { + ConversationManager, + type ProactiveCompressionConfig, + type ConversationManagerReduceOptions, +} from './conversation-manager.js' +import { logger } from '../logging/logger.js' +import { normalizeError } from '../errors.js' +import type { Model } from '../models/model.js' + +const DEFAULT_SUMMARIZATION_PROMPT = `You are a conversation summarizer. Provide a concise summary of the conversation \ +history. + +Format Requirements: +- You MUST create a structured and concise summary in bullet-point format. +- You MUST NOT respond conversationally. +- You MUST NOT address the user directly. +- You MUST NOT comment on tool availability. + +Assumptions: +- You MUST NOT assume tool executions failed unless otherwise stated. + +Task: +Your task is to create a structured summary document: +- It MUST contain bullet points with key topics and questions covered +- It MUST contain bullet points for all significant tools executed and their results +- It MUST contain bullet points for any code or technical information shared +- It MUST contain a section of key insights gained +- It MUST format the summary in the third person + +Example format: + +## Conversation Summary +* Topic 1: Key information +* Topic 2: Key information + +## Tools Executed +* Tool X: Result Y` + +/** + * Configuration for the summarization conversation manager. + */ +export type SummarizingConversationManagerConfig = { + /** + * Model to use for generating summaries. When provided, overrides the model + * attached to the agent. Useful when you want to use a different model than + * the one attached to the agent. + */ + model?: Model + + /** + * Ratio of messages to summarize when context overflow occurs. + * Value is clamped to [0.1, 0.8]. Defaults to 0.3 (summarize 30% of oldest messages). + */ + summaryRatio?: number + + /** + * Minimum number of recent messages to always keep. + * Defaults to 10. + */ + preserveRecentMessages?: number + + /** + * Custom system prompt for summarization. If not provided, uses a default + * prompt that produces structured bullet-point summaries. + */ + summarizationSystemPrompt?: string + + /** + * Enable proactive context compression before the model call. + * + * - `true`: compress when 70% of the context window is used (default threshold). + * - `{ compressionThreshold: number }`: compress at the specified ratio (0, 1]. + * - `false` or omitted: disabled, only reactive overflow recovery is used. + */ + proactiveCompression?: boolean | ProactiveCompressionConfig +} + +/** + * Implements a summarization strategy for managing conversation history. + * + * When a {@link ContextWindowOverflowError} occurs, this manager summarizes + * the oldest messages using a model call and replaces them with a single + * summary message, preserving context that would otherwise be lost. + */ +export class SummarizingConversationManager extends ConversationManager { + readonly name = 'strands:summarizing-conversation-manager' + + private readonly _model: Model | undefined + private readonly _summaryRatio: number + private readonly _preserveRecentMessages: number + private readonly _summarizationSystemPrompt: string + + constructor(config?: SummarizingConversationManagerConfig) { + super(config) + this._model = config?.model + // clamped [0.1, 0.8] + this._summaryRatio = Math.max(0.1, Math.min(0.8, config?.summaryRatio ?? 0.3)) + this._preserveRecentMessages = config?.preserveRecentMessages ?? 10 + this._summarizationSystemPrompt = config?.summarizationSystemPrompt ?? DEFAULT_SUMMARIZATION_PROMPT + } + + /** + * Reduce the conversation history by summarizing older messages. + * + * When `error` is set (reactive overflow recovery), summarization failure is rethrown + * with the original error as cause — the agent loop must not proceed with an overflow. + * + * When `error` is undefined (proactive compression), summarization failure is logged + * and returns `false` — the model call proceeds regardless. + * + * @param options - The reduction options + * @returns `true` if the history was reduced, `false` otherwise + */ + async reduce({ agent, model, error }: ConversationManagerReduceOptions): Promise { + try { + return await this._summarizeOldest(agent, this._model ?? model) + } catch (summarizationError) { + if (error) { + // Reactive: rethrow so the ContextWindowOverflowError propagates + logger.error(`error=<${summarizationError}> | summarization failed`) + const wrapped = normalizeError(summarizationError) + wrapped.cause = error + throw wrapped + } + // Proactive: best-effort, swallow errors so the model call can still proceed. + logger.warn(`error=<${summarizationError}> | proactive summarization failed, continuing`) + return false + } + } + + /** + * Summarize the oldest messages and replace them with a summary. + * + * @param agent - The agent instance + * @param model - The model to use for summarization + * @returns `true` if the history was reduced, `false` otherwise + */ + private async _summarizeOldest(agent: LocalAgent, model: Model): Promise { + const messages = agent.messages + + // Calculate how many messages to summarize + let messagesToSummarizeCount = Math.max(1, Math.floor(messages.length * this._summaryRatio)) + + // Don't touch recent messages + messagesToSummarizeCount = Math.min(messagesToSummarizeCount, messages.length - this._preserveRecentMessages) + + if (messagesToSummarizeCount <= 0) { + logger.warn( + `preserve_recent=<${this._preserveRecentMessages}>, messages=<${messages.length}> | insufficient messages for summarization` + ) + return false + } + + // Adjust split point to avoid breaking tool use/result pairs + messagesToSummarizeCount = this._adjustSplitPointForToolPairs(messages, messagesToSummarizeCount) + + const messagesToSummarize = messages.slice(0, messagesToSummarizeCount) + + // Generate summary via model call + const summaryMessage = await this._generateSummary(messagesToSummarize, model) + + // Replace summarized messages with the summary + messages.splice(0, messagesToSummarizeCount, summaryMessage) + + return true + } + + /** + * Generate a summary of the provided messages by calling the model directly. + * + * @param messagesToSummarize - The messages to summarize + * @returns A user-role message containing the summary + */ + private async _generateSummary(messagesToSummarize: Message[], model: Model): Promise { + const summarizationMessages = [ + ...messagesToSummarize, + new Message({ + role: 'user', + content: [new TextBlock('Please summarize this conversation.')], + }), + ] + + const stream = model.streamAggregated(summarizationMessages, { + systemPrompt: this._summarizationSystemPrompt, + }) + + // Manual .next() loop is required: streamAggregated returns its final result + // as the generator return value (done:true), which for-await-of discards. + let result: Awaited> | undefined + for (;;) { + result = await stream.next() + if (result.done) break + } + + if (!result?.done || !result.value) { + throw new Error('Failed to generate summary: no response from model') + } + + // Return the summary as a user-role message so it's valid as conversation history + return new Message({ + role: 'user', + content: result.value.message.content, + }) + } + + /** + * Adjust the split point to avoid breaking tool use/result pairs. + * + * Walks the split point forward until the message at that position is neither + * an orphaned toolResult nor a toolUse without an immediately following toolResult. + * + * @param messages - The full message array + * @param splitPoint - The initially calculated split point + * @returns The adjusted split point + * @throws If no valid split point can be found + */ + private _adjustSplitPointForToolPairs(messages: Message[], splitPoint: number): number { + if (splitPoint >= messages.length) { + return splitPoint + } + + while (splitPoint < messages.length) { + const message = messages[splitPoint]! + + // Can't leave an orphaned toolResult at the start + const hasToolResult = message.content.some((block) => block.type === 'toolResultBlock') + if (hasToolResult) { + splitPoint++ + continue + } + + // A toolUse is only valid at the boundary if the next message is its toolResult + const hasToolUse = message.content.some((block) => block.type === 'toolUseBlock') + if (hasToolUse) { + const nextMessage = messages[splitPoint + 1] + const nextHasToolResult = nextMessage?.content.some((block) => block.type === 'toolResultBlock') + if (!nextHasToolResult) { + splitPoint++ + continue + } + } + + break + } + + // If we walked past all messages, no valid split point exists + if (splitPoint >= messages.length) { + throw new Error('Unable to find valid split point for summarization') + } + + return splitPoint + } +} diff --git a/strands-ts/src/errors.ts b/strands-ts/src/errors.ts new file mode 100644 index 0000000000..f721e0983b --- /dev/null +++ b/strands-ts/src/errors.ts @@ -0,0 +1,244 @@ +/** + * Error types for the Strands Agents TypeScript SDK. + * + * These error classes represent specific error conditions that can occur + * during agent execution and model provider interactions. + */ + +import type { Message } from './types/messages.js' +import type { JSONValue } from './types/json.js' + +/** + * Base exception class for all model-related errors. + * + * This class serves as the common base type for errors that originate from + * model provider interactions. By catching ModelError, consumers can handle + * all model-related errors uniformly while still having access to specific + * error types through instanceof checks. + */ +export class ModelError extends Error { + /** + * Creates a new ModelError. + * + * @param message - Error message describing the model error + * @param options - Optional error options including the cause + */ + constructor(message: string, options?: { cause?: unknown }) { + super(message, options) + this.name = 'ModelError' + } +} + +/** + * Error thrown when input exceeds the model's context window. + * + * This error indicates that the combined length of the input (prompt, messages, + * system prompt, and tool definitions) exceeds the maximum context window size + * supported by the model. + */ +export class ContextWindowOverflowError extends ModelError { + /** + * Creates a new ContextWindowOverflowError. + * + * @param message - Error message describing the context overflow + */ + constructor(message: string) { + super(message) + this.name = 'ContextWindowOverflowError' + } +} + +/** + * Error thrown when the model reaches its maximum token limit during generation. + * + * This error indicates that the model stopped generating content because it reached + * the maximum number of tokens allowed for the response. This is an unrecoverable + * state that requires intervention, such as reducing the input size or adjusting + * the max tokens parameter. + */ +export class MaxTokensError extends ModelError { + /** + * The partial assistant message that was generated before hitting the token limit. + * This can be useful for understanding what the model was trying to generate. + */ + public readonly partialMessage: Message + + /** + * Creates a new MaxTokensError. + * + * @param message - Error message describing the max tokens condition + * @param partialMessage - The partial assistant message generated before the limit + */ + constructor(message: string, partialMessage: Message) { + super(message) + this.name = 'MaxTokensError' + this.partialMessage = partialMessage + } +} + +/** + * Error thrown when attempting to serialize a value that is not JSON-serializable. + * + * This error indicates that a value contains non-serializable types such as functions, + * symbols, or undefined values that cannot be converted to JSON. + */ +export class JsonValidationError extends Error { + /** + * Creates a new JsonValidationError. + * + * @param message - Error message describing the validation failure + */ + constructor(message: string) { + super(message) + this.name = 'JsonValidationError' + } +} + +/** + * Error thrown when attempting to invoke an agent that is already processing an invocation. + * + * This error indicates that invoke() or stream() was called while the agent is already + * executing. Agents can only process one invocation at a time to prevent state corruption. + */ +export class ConcurrentInvocationError extends Error { + /** + * Creates a new ConcurrentInvocationError. + * + * @param message - Error message describing the concurrent invocation attempt + */ + constructor(message: string) { + super(message) + this.name = 'ConcurrentInvocationError' + } +} + +/** + * Error thrown when a model provider returns a throttling or rate limit error. + * + * This error indicates that the model API has rate limited the request. Users can + * handle this error in hooks to implement custom retry strategies using the + * `AfterModelCallEvent.retry` mechanism. + */ +export class ModelThrottledError extends ModelError { + /** + * Creates a new ModelThrottledError. + * + * @param message - Error message describing the throttling condition + * @param options - Optional error options including cause for error chaining + */ + constructor(message: string, options?: ErrorOptions) { + super(message, options) + this.name = 'ModelThrottledError' + } +} + +/** + * Normalizes an unknown error value to an Error instance. + * + * This helper ensures that any thrown value (Error, string, number, etc.) + * is converted to a proper Error object for consistent error handling. + * + * @param error - The error value to normalize + * @returns An Error instance + */ +export function normalizeError(error: unknown): Error { + return error instanceof Error ? error : new Error(String(error)) +} + +/** + * Serializes an Error to a JSON-compatible value. + * Use {@link normalizeError} for the reverse direction. + */ +export function serializeError(error: Error): JSONValue { + return error.message +} + +/** + * Error thrown when session operations fail. + * + * This error indicates failures in session storage operations such as + * reading, writing, or managing session data. + */ +export class SessionError extends Error { + /** + * Creates a new SessionError. + * + * @param message - Error message describing the session error + * @param options - Optional error options including cause for error chaining + */ + constructor(message: string, options?: ErrorOptions) { + super(message, options) + this.name = 'SessionError' + } +} + +/** + * Thrown when a model provider's native token counting API fails. + * + * This error is used as internal control flow within provider `countTokens()` overrides. + * When caught, the provider falls back to the base class heuristic estimation. + */ +export class ProviderTokenCountError extends ModelError { + constructor(message: string, options?: { cause?: unknown }) { + super(message, options) + this.name = 'ProviderTokenCountError' + } +} + +/** + * Thrown when a tool fails validation during registration. + */ +export class ToolValidationError extends Error { + constructor(message: string) { + super(message) + this.name = 'ToolValidationError' + } +} + +/** + * Thrown when the model fails to produce structured output. + * This occurs when the LLM refuses to use the structured output tool + * even after being forced via toolChoice. + */ +export class StructuredOutputError extends Error { + constructor(message: string) { + super(message) + this.name = 'StructuredOutputError' + } +} + +/** + * Error thrown when a tool cannot be found by name. + * + * Thrown by {@link ToolRegistry.resolve} when the requested tool name doesn't + * match any registered tool, even after underscore-to-hyphen normalization + * and case-insensitive matching. + */ +export class ToolNotFoundError extends Error { + /** The tool name that was requested but not found. */ + public readonly toolName: string + + /** + * Creates a new ToolNotFoundError. + * + * @param toolName - The tool name that was not found + */ + constructor(toolName: string) { + super(`Tool '${toolName}' not found`) + this.name = 'ToolNotFoundError' + this.toolName = toolName + } +} + +/** + * Internal control-flow mechanism for unwinding nested `yield*` generator chains + * when cancellation is detected during model streaming. + * Caught at the `_stream()` level and converted to an `AgentResult` with `stopReason: 'cancelled'`. + * Not exported from the package — never thrown to users. + */ +export class CancelledError extends Error { + constructor() { + super('Agent invocation cancelled') + this.name = 'CancelledError' + } +} diff --git a/strands-ts/src/hooks/__tests__/events.test.ts b/strands-ts/src/hooks/__tests__/events.test.ts new file mode 100644 index 0000000000..c952a704e9 --- /dev/null +++ b/strands-ts/src/hooks/__tests__/events.test.ts @@ -0,0 +1,1226 @@ +import { describe, expect, it } from 'vitest' +import { + InitializedEvent, + AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, + AfterToolsEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + BeforeToolsEvent, + MessageAddedEvent, + ModelStreamUpdateEvent, + ContentBlockEvent, + ModelMessageEvent, + ToolResultEvent, + ToolStreamUpdateEvent, + AgentResultEvent, +} from '../events.js' +import { Agent } from '../../agent/agent.js' +import { AgentResult } from '../../types/agent.js' +import { AgentMetrics } from '../../telemetry/meter.js' +import { Message, TextBlock, ToolResultBlock, ToolUseBlock } from '../../types/messages.js' +import { FunctionTool } from '../../tools/function-tool.js' +import { ToolStreamEvent } from '../../tools/tool.js' + +describe('InitializedEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const event = new InitializedEvent({ agent }) + + expect(event).toEqual({ + type: 'initializedEvent', + agent: agent, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + }) + + it('returns false for _shouldReverseCallbacks', () => { + const agent = new Agent() + const event = new InitializedEvent({ agent }) + expect(event._shouldReverseCallbacks()).toBe(false) + }) +}) + +describe('BeforeInvocationEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const event = new BeforeInvocationEvent({ agent, invocationState: {} }) + + expect(event).toEqual({ + type: 'beforeInvocationEvent', + agent: agent, + cancel: false, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + }) + + it('returns false for _shouldReverseCallbacks', () => { + const agent = new Agent() + const event = new BeforeInvocationEvent({ agent, invocationState: {} }) + expect(event._shouldReverseCallbacks()).toBe(false) + }) + + it('allows cancel to be set to true', () => { + const agent = new Agent() + const event = new BeforeInvocationEvent({ agent, invocationState: {} }) + + expect(event.cancel).toBe(false) + event.cancel = true + expect(event.cancel).toBe(true) + }) + + it('allows cancel to be set to a string message', () => { + const agent = new Agent() + const event = new BeforeInvocationEvent({ agent, invocationState: {} }) + + event.cancel = 'unauthorized' + expect(event.cancel).toBe('unauthorized') + }) +}) + +describe('AfterInvocationEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const event = new AfterInvocationEvent({ agent, invocationState: {} }) + + expect(event).toEqual({ + type: 'afterInvocationEvent', + agent: agent, + invocationState: {}, + resume: undefined, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + }) + + it('returns true for _shouldReverseCallbacks', () => { + const agent = new Agent() + const event = new AfterInvocationEvent({ agent, invocationState: {} }) + expect(event._shouldReverseCallbacks()).toBe(true) + }) + + it('allows resume to be set to new input', () => { + const agent = new Agent() + const event = new AfterInvocationEvent({ agent, invocationState: {} }) + + expect(event.resume).toBeUndefined() + + event.resume = 'follow-up prompt' + expect(event.resume).toBe('follow-up prompt') + }) +}) + +describe('MessageAddedEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [new TextBlock('Hello')] }) + const event = new MessageAddedEvent({ agent, message, invocationState: {} }) + + expect(event).toEqual({ + type: 'messageAddedEvent', + agent: agent, + message: message, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.message = message + }) + + it('returns false for _shouldReverseCallbacks', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [] }) + const event = new MessageAddedEvent({ agent, message, invocationState: {} }) + expect(event._shouldReverseCallbacks()).toBe(false) + }) +}) + +describe('BeforeToolCallEvent', () => { + it('creates instance with correct properties when tool is found', () => { + const agent = new Agent() + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test tool', + inputSchema: {}, + callback: () => 'result', + }) + const toolUse = { + name: 'testTool', + toolUseId: 'test-id', + input: { arg: 'value' }, + } + const event = new BeforeToolCallEvent({ agent, toolUse, tool, invocationState: {} }) + + expect(event).toEqual({ + type: 'beforeToolCallEvent', + agent: agent, + toolUse: toolUse, + tool: tool, + cancel: false, + invocationState: {}, + selectedTool: undefined, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.tool = tool + }) + + it('creates instance with undefined tool when tool is not found', () => { + const agent = new Agent() + const toolUse = { + name: 'unknownTool', + toolUseId: 'test-id', + input: {}, + } + const event = new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + + expect(event).toEqual({ + type: 'beforeToolCallEvent', + agent: agent, + toolUse: toolUse, + tool: undefined, + cancel: false, + invocationState: {}, + selectedTool: undefined, + }) + }) + + it('returns false for _shouldReverseCallbacks', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + const event = new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + expect(event._shouldReverseCallbacks()).toBe(false) + }) + + it('allows cancel to be set to true', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + const event = new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + + expect(event.cancel).toBe(false) + event.cancel = true + expect(event.cancel).toBe(true) + }) + + it('allows cancel to be set to a string message', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + const event = new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + + event.cancel = 'tool not allowed' + expect(event.cancel).toBe('tool not allowed') + }) + + it('allows selectedTool to be set to a replacement tool', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + const event = new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + + expect(event.selectedTool).toBeUndefined() + + const replacement = new FunctionTool({ + name: 'replacement', + description: 'Replacement', + inputSchema: {}, + callback: () => 'ok', + }) + event.selectedTool = replacement + expect(event.selectedTool).toBe(replacement) + }) + + it('allows mutating toolUse fields in-place', () => { + const agent = new Agent() + const toolUse = { name: 'orig', toolUseId: 'id', input: { a: 1 } } + const event = new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + + event.toolUse.input = { a: 2, b: 3 } + event.toolUse.name = 'renamed' + expect(event.toolUse).toEqual({ name: 'renamed', toolUseId: 'id', input: { a: 2, b: 3 } }) + }) + + it('allows reassigning toolUse to a new object', () => { + const agent = new Agent() + const toolUse = { name: 'orig', toolUseId: 'id', input: {} } + const event = new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + + event.toolUse = { name: 'new', toolUseId: 'new-id', input: { x: 1 } } + expect(event.toolUse).toEqual({ name: 'new', toolUseId: 'new-id', input: { x: 1 } }) + }) +}) + +describe('AfterToolCallEvent', () => { + it('creates instance with correct properties on success', () => { + const agent = new Agent() + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test tool', + inputSchema: {}, + callback: () => 'result', + }) + const toolUse = { + name: 'testTool', + toolUseId: 'test-id', + input: {}, + } + const result = new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('Success')], + }) + const event = new AfterToolCallEvent({ agent, toolUse, tool, result, invocationState: {} }) + + expect(event).toEqual({ + type: 'afterToolCallEvent', + agent: agent, + toolUse: toolUse, + tool: tool, + result: result, + error: undefined, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.toolUse = toolUse + // @ts-expect-error verifying that property is readonly + event.tool = tool + }) + + it('allows result to be replaced', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + const result = new ToolResultBlock({ toolUseId: 'id', status: 'success', content: [new TextBlock('original')] }) + const event = new AfterToolCallEvent({ agent, toolUse, tool: undefined, result, invocationState: {} }) + + const replacedResult = new ToolResultBlock({ + toolUseId: 'id', + status: 'success', + content: [new TextBlock('replaced')], + }) + event.result = replacedResult + expect(event.result).toBe(replacedResult) + }) + + it('creates instance with error property when tool execution fails', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + const result = new ToolResultBlock({ + toolUseId: 'id', + status: 'error', + content: [new TextBlock('Error')], + }) + const error = new Error('Tool failed') + const event = new AfterToolCallEvent({ agent, toolUse, tool: undefined, result, error, invocationState: {} }) + + expect(event).toEqual({ + type: 'afterToolCallEvent', + agent: agent, + toolUse: toolUse, + tool: undefined, + result: result, + error: error, + invocationState: {}, + }) + }) + + it('returns true for _shouldReverseCallbacks', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + const result = new ToolResultBlock({ + toolUseId: 'id', + status: 'success', + content: [], + }) + const event = new AfterToolCallEvent({ agent, toolUse, tool: undefined, result, invocationState: {} }) + expect(event._shouldReverseCallbacks()).toBe(true) + }) + + it('allows retry to be set when error is present', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + const result = new ToolResultBlock({ + toolUseId: 'id', + status: 'error', + content: [new TextBlock('Error')], + }) + const error = new Error('Tool failed') + const event = new AfterToolCallEvent({ agent, toolUse, tool: undefined, result, error, invocationState: {} }) + + expect(event.retry).toBeUndefined() + + event.retry = true + expect(event.retry).toBe(true) + + event.retry = false + expect(event.retry).toBe(false) + }) + + it('allows retry to be set on success', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + const result = new ToolResultBlock({ + toolUseId: 'id', + status: 'success', + content: [new TextBlock('Success')], + }) + const event = new AfterToolCallEvent({ agent, toolUse, tool: undefined, result, invocationState: {} }) + + expect(event.retry).toBeUndefined() + + event.retry = true + expect(event.retry).toBe(true) + }) +}) + +describe('BeforeModelCallEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const event = new BeforeModelCallEvent({ agent, model: agent.model, invocationState: {} }) + + expect(event).toEqual({ + type: 'beforeModelCallEvent', + agent: agent, + model: agent.model, + cancel: false, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + }) + + it('includes projectedInputTokens when provided', () => { + const agent = new Agent() + const event = new BeforeModelCallEvent({ + agent, + model: agent.model, + invocationState: {}, + projectedInputTokens: 500, + }) + + expect(event).toEqual({ + type: 'beforeModelCallEvent', + agent, + model: agent.model, + cancel: false, + invocationState: {}, + projectedInputTokens: 500, + }) + expect(event.toJSON()).toStrictEqual({ + type: 'beforeModelCallEvent', + projectedInputTokens: 500, + }) + }) + + it('excludes projectedInputTokens from toJSON when not provided', () => { + const agent = new Agent() + const event = new BeforeModelCallEvent({ agent, model: agent.model, invocationState: {} }) + + expect(event.projectedInputTokens).toBeUndefined() + expect(event.toJSON()).toStrictEqual({ type: 'beforeModelCallEvent' }) + }) + + it('returns false for _shouldReverseCallbacks', () => { + const agent = new Agent() + const event = new BeforeModelCallEvent({ agent, model: agent.model, invocationState: {} }) + expect(event._shouldReverseCallbacks()).toBe(false) + }) + + it('allows cancel to be set to true', () => { + const agent = new Agent() + const event = new BeforeModelCallEvent({ agent, model: agent.model, invocationState: {} }) + + expect(event.cancel).toBe(false) + event.cancel = true + expect(event.cancel).toBe(true) + }) + + it('allows cancel to be set to a string message', () => { + const agent = new Agent() + const event = new BeforeModelCallEvent({ agent, model: agent.model, invocationState: {} }) + + event.cancel = 'rate limited' + expect(event.cancel).toBe('rate limited') + }) +}) + +describe('AfterModelCallEvent', () => { + it('creates instance with correct properties on success', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [new TextBlock('Response')] }) + const stopReason = 'endTurn' + const response = { message, stopReason } + const event = new AfterModelCallEvent({ + agent, + model: agent.model, + attemptCount: 1, + stopData: response, + invocationState: {}, + }) + + expect(event).toEqual({ + type: 'afterModelCallEvent', + agent: agent, + model: agent.model, + attemptCount: 1, + stopData: response, + error: undefined, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.stopData = response + }) + + it('creates instance with error property when model invocation fails', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [] }) + const error = new Error('Model failed') + const response = { message, stopReason: 'error' } + const event = new AfterModelCallEvent({ + agent, + model: agent.model, + attemptCount: 1, + stopData: response, + error, + invocationState: {}, + }) + + expect(event).toEqual({ + type: 'afterModelCallEvent', + agent: agent, + model: agent.model, + attemptCount: 1, + stopData: response, + error: error, + invocationState: {}, + }) + }) + + it('returns true for _shouldReverseCallbacks', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [] }) + const response = { message, stopReason: 'endTurn' } + const event = new AfterModelCallEvent({ + agent, + model: agent.model, + attemptCount: 1, + stopData: response, + invocationState: {}, + }) + expect(event._shouldReverseCallbacks()).toBe(true) + }) + + it('allows retry to be set when error is present', () => { + const agent = new Agent() + const error = new Error('Model failed') + const event = new AfterModelCallEvent({ agent, model: agent.model, attemptCount: 1, error, invocationState: {} }) + + // Initially undefined + expect(event.retry).toBeUndefined() + + // Can be set to true + event.retry = true + expect(event.retry).toBe(true) + + // Can be set to false + event.retry = false + expect(event.retry).toBe(false) + }) + + it('retry is optional and defaults to undefined', () => { + const agent = new Agent() + const error = new Error('Model failed') + const event = new AfterModelCallEvent({ agent, model: agent.model, attemptCount: 1, error, invocationState: {} }) + + expect(event.retry).toBeUndefined() + }) +}) + +describe('ModelStreamUpdateEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const streamEvent = { + type: 'modelMessageStartEvent' as const, + role: 'assistant' as const, + } + const hookEvent = new ModelStreamUpdateEvent({ agent, event: streamEvent, invocationState: {} }) + + expect(hookEvent).toEqual({ + type: 'modelStreamUpdateEvent', + agent: agent, + event: streamEvent, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + hookEvent.agent = new Agent() + // @ts-expect-error verifying that property is readonly + hookEvent.event = streamEvent + }) +}) + +describe('ContentBlockEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const contentBlock = new TextBlock('Hello') + const event = new ContentBlockEvent({ agent, contentBlock, invocationState: {} }) + + expect(event).toEqual({ + type: 'contentBlockEvent', + agent: agent, + contentBlock: contentBlock, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.contentBlock = contentBlock + }) +}) + +describe('ModelMessageEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [new TextBlock('Hello')] }) + const event = new ModelMessageEvent({ agent, message, stopReason: 'endTurn', invocationState: {} }) + + expect(event).toEqual({ + type: 'modelMessageEvent', + agent: agent, + message: message, + stopReason: 'endTurn', + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.message = message + // @ts-expect-error verifying that property is readonly + event.stopReason = 'endTurn' + }) +}) + +describe('ToolResultEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const toolResult = new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('Result')], + }) + const event = new ToolResultEvent({ agent, result: toolResult, invocationState: {} }) + + expect(event).toEqual({ + type: 'toolResultEvent', + agent: agent, + result: toolResult, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.result = toolResult + }) +}) + +describe('ToolStreamUpdateEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const toolStreamEvent = new ToolStreamEvent({ data: 'progress' }) + const event = new ToolStreamUpdateEvent({ agent, event: toolStreamEvent, invocationState: {} }) + + expect(event).toEqual({ + type: 'toolStreamUpdateEvent', + agent: agent, + event: toolStreamEvent, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.event = toolStreamEvent + }) +}) + +describe('AgentResultEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: new Message({ role: 'assistant', content: [new TextBlock('Done')] }), + metrics: new AgentMetrics(), + invocationState: {}, + }) + const event = new AgentResultEvent({ agent, result, invocationState: {} }) + + expect(event).toEqual({ + type: 'agentResultEvent', + agent: agent, + result: result, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.result = result + }) +}) + +describe('BeforeToolsEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const message = new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ + name: 'testTool', + toolUseId: 'test-id', + input: { arg: 'value' }, + }), + ], + }) + const event = new BeforeToolsEvent({ agent, message, invocationState: {} }) + + expect(event).toEqual({ + type: 'beforeToolsEvent', + agent: agent, + message: message, + cancel: false, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.message = message + }) + + it('returns false for _shouldReverseCallbacks', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [] }) + const event = new BeforeToolsEvent({ agent, message, invocationState: {} }) + expect(event._shouldReverseCallbacks()).toBe(false) + }) + + it('allows cancel to be set to true', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [] }) + const event = new BeforeToolsEvent({ agent, message, invocationState: {} }) + + expect(event.cancel).toBe(false) + event.cancel = true + expect(event.cancel).toBe(true) + }) + + it('allows cancel to be set to a string message', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [] }) + const event = new BeforeToolsEvent({ agent, message, invocationState: {} }) + + event.cancel = 'tools not allowed' + expect(event.cancel).toBe('tools not allowed') + }) +}) + +describe('AfterToolsEvent', () => { + it('creates instance with correct properties', () => { + const agent = new Agent() + const message = new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }) + const event = new AfterToolsEvent({ agent, message, invocationState: {} }) + + expect(event).toEqual({ + type: 'afterToolsEvent', + agent: agent, + message: message, + invocationState: {}, + endTurn: false, + }) + // @ts-expect-error verifying that property is readonly + event.agent = new Agent() + // @ts-expect-error verifying that property is readonly + event.message = message + }) + + it('returns true for _shouldReverseCallbacks', () => { + const agent = new Agent() + const message = new Message({ role: 'user', content: [] }) + const event = new AfterToolsEvent({ agent, message, invocationState: {} }) + expect(event._shouldReverseCallbacks()).toBe(true) + }) + + it('defaults endTurn to false and accepts boolean or string', () => { + const agent = new Agent() + const message = new Message({ role: 'user', content: [] }) + const event = new AfterToolsEvent({ agent, message, invocationState: {} }) + + expect(event.endTurn).toBe(false) + + event.endTurn = true + expect(event.endTurn).toBe(true) + + event.endTurn = 'enough information gathered' + expect(event.endTurn).toBe('enough information gathered') + }) +}) + +// ===================== toJSON serialization tests ===================== + +describe('toJSON serialization', () => { + describe('InitializedEvent', () => { + it('excludes agent and returns only type', () => { + const agent = new Agent() + const event = new InitializedEvent({ agent }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ type: 'initializedEvent' }) + }) + }) + + describe('BeforeInvocationEvent', () => { + it('excludes agent and returns only type', () => { + const agent = new Agent() + const event = new BeforeInvocationEvent({ agent, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ type: 'beforeInvocationEvent' }) + }) + }) + + describe('AfterInvocationEvent', () => { + it('excludes agent and returns only type', () => { + const agent = new Agent() + const event = new AfterInvocationEvent({ agent, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ type: 'afterInvocationEvent' }) + }) + }) + + describe('BeforeModelCallEvent', () => { + it('excludes agent and model and returns only type', () => { + const agent = new Agent() + const event = new BeforeModelCallEvent({ agent, model: agent.model, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ type: 'beforeModelCallEvent' }) + }) + }) + + describe('MessageAddedEvent', () => { + it('includes message and excludes agent', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [new TextBlock('Hello')] }) + const event = new MessageAddedEvent({ agent, message, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'messageAddedEvent', + message: { role: 'assistant', content: [{ text: 'Hello' }] }, + }) + }) + }) + + describe('ModelStreamUpdateEvent', () => { + it('includes stream event and excludes agent', () => { + const agent = new Agent() + const streamEvent = { + type: 'modelContentBlockDeltaEvent' as const, + delta: { type: 'textDelta' as const, text: 'Hi' }, + } + const event = new ModelStreamUpdateEvent({ agent, event: streamEvent, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'modelStreamUpdateEvent', + event: { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'Hi' } }, + }) + }) + }) + + describe('ContentBlockEvent', () => { + it('includes content block and excludes agent', () => { + const agent = new Agent() + const contentBlock = new TextBlock('Hello world') + const event = new ContentBlockEvent({ agent, contentBlock, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'contentBlockEvent', + contentBlock: { text: 'Hello world' }, + }) + }) + }) + + describe('ModelMessageEvent', () => { + it('includes message and stopReason, excludes agent', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [new TextBlock('Done')] }) + const event = new ModelMessageEvent({ agent, message, stopReason: 'endTurn', invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'modelMessageEvent', + message: { role: 'assistant', content: [{ text: 'Done' }] }, + stopReason: 'endTurn', + }) + }) + }) + + describe('ToolResultEvent', () => { + it('includes result and excludes agent', () => { + const agent = new Agent() + const result = new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new TextBlock('42')], + }) + const event = new ToolResultEvent({ agent, result, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'toolResultEvent', + result: { toolResult: { toolUseId: 'tool-1', status: 'success', content: [{ text: '42' }] } }, + }) + }) + }) + + describe('ToolStreamUpdateEvent', () => { + it('includes tool stream event and excludes agent', () => { + const agent = new Agent() + const toolStreamEvent = new ToolStreamEvent({ data: { progress: 50 } }) + const event = new ToolStreamUpdateEvent({ agent, event: toolStreamEvent, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'toolStreamUpdateEvent', + event: { type: 'toolStreamEvent', data: { progress: 50 } }, + }) + }) + }) + + describe('AgentResultEvent', () => { + it('includes result and excludes agent', () => { + const agent = new Agent() + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: new Message({ role: 'assistant', content: [new TextBlock('Done')] }), + metrics: new AgentMetrics(), + invocationState: {}, + }) + const event = new AgentResultEvent({ agent, result, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'agentResultEvent', + result: { + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: { role: 'assistant', content: [{ text: 'Done' }] }, + }, + }) + }) + }) + + describe('BeforeToolCallEvent', () => { + it('includes toolUse and excludes agent, tool, and cancel', () => { + const agent = new Agent() + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test', + inputSchema: {}, + callback: () => 'result', + }) + const toolUse = { name: 'testTool', toolUseId: 'id-1', input: { query: 'hello' } } + const event = new BeforeToolCallEvent({ agent, toolUse, tool, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'beforeToolCallEvent', + toolUse: { name: 'testTool', toolUseId: 'id-1', input: { query: 'hello' } }, + }) + }) + }) + + describe('AfterToolCallEvent', () => { + it('includes toolUse and result, excludes agent and tool on success', () => { + const agent = new Agent() + const toolUse = { name: 'calc', toolUseId: 'id-1', input: {} } + const result = new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('42')], + }) + const event = new AfterToolCallEvent({ agent, toolUse, tool: undefined, result, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'afterToolCallEvent', + toolUse: { name: 'calc', toolUseId: 'id-1', input: {} }, + result: { toolResult: { toolUseId: 'id-1', status: 'success', content: [{ text: '42' }] } }, + }) + }) + + it('converts error to message string and excludes retry', () => { + const agent = new Agent() + const toolUse = { name: 'calc', toolUseId: 'id-1', input: {} } + const result = new ToolResultBlock({ + toolUseId: 'id-1', + status: 'error', + content: [new TextBlock('Error')], + }) + const error = new Error('Tool crashed') + const event = new AfterToolCallEvent({ agent, toolUse, tool: undefined, result, error, invocationState: {} }) + event.retry = true + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'afterToolCallEvent', + toolUse: { name: 'calc', toolUseId: 'id-1', input: {} }, + result: { toolResult: { toolUseId: 'id-1', status: 'error', content: [{ text: 'Error' }] } }, + error: { message: 'Tool crashed' }, + }) + }) + }) + + describe('AfterModelCallEvent', () => { + it('includes stopData and attemptCount and excludes agent and model on success', () => { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [new TextBlock('Hi')] }) + const stopData = { message, stopReason: 'endTurn' as const } + const event = new AfterModelCallEvent({ + agent, + model: agent.model, + attemptCount: 2, + stopData, + invocationState: {}, + }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'afterModelCallEvent', + attemptCount: 2, + stopData: { + message: { role: 'assistant', content: [{ text: 'Hi' }] }, + stopReason: 'endTurn', + }, + }) + }) + + it('converts error to message string and excludes retry', () => { + const agent = new Agent() + const error = new Error('Model failed') + const event = new AfterModelCallEvent({ agent, model: agent.model, attemptCount: 1, error, invocationState: {} }) + event.retry = true + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'afterModelCallEvent', + attemptCount: 1, + error: { message: 'Model failed' }, + }) + }) + }) + + describe('BeforeToolsEvent', () => { + it('includes message and excludes agent and cancel', () => { + const agent = new Agent() + const message = new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'calc', toolUseId: 'id-1', input: {} })], + }) + const event = new BeforeToolsEvent({ agent, message, invocationState: {} }) + event.cancel = 'not allowed' + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'beforeToolsEvent', + message: { role: 'assistant', content: [{ toolUse: { name: 'calc', toolUseId: 'id-1', input: {} } }] }, + }) + }) + }) + + describe('AfterToolsEvent', () => { + it('includes message and excludes agent', () => { + const agent = new Agent() + const message = new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('Done')], + }), + ], + }) + const event = new AfterToolsEvent({ agent, message, invocationState: {} }) + const json = JSON.parse(JSON.stringify(event)) + + expect(json).toStrictEqual({ + type: 'afterToolsEvent', + message: { + role: 'user', + content: [{ toolResult: { toolUseId: 'id-1', status: 'success', content: [{ text: 'Done' }] } }], + }, + }) + }) + }) + + describe('agent reference is never serialized', () => { + it('JSON.stringify output never contains agent properties', () => { + const agent = new Agent() + // Add messages to make agent heavy + agent.messages.push(new Message({ role: 'user', content: [new TextBlock('Hello '.repeat(100))] })) + + const event = new ModelStreamUpdateEvent({ + agent, + event: { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'Hi' } }, + invocationState: {}, + }) + const json = JSON.stringify(event) + + // Should be small (no agent serialized) + expect(json.length).toBeLessThan(200) + expect(json).not.toContain('Hello Hello') + expect(json).not.toContain('appState') + expect(json).not.toContain('toolRegistry') + }) + }) +}) + +// ===================== Serialization completeness tests ===================== +// Ensures that if a new field is added to an event class, it must either be +// included in toJSON() or explicitly added to the exclusion set. + +describe('toJSON serialization completeness', () => { + /** + * Fields that should NEVER appear in toJSON() output. + * If you add a new field to an event and it should be excluded from wire serialization, + * add it here. Otherwise, add it to toJSON() so it gets serialized. + */ + const EXCLUDED_FIELDS = new Set([ + 'agent', + 'model', + 'tool', + 'cancel', + 'retry', + 'invocationState', + 'selectedTool', + 'resume', + 'endTurn', + ]) + + /** + * Fields where toJSON() transforms the value (e.g., Error to message object). + * These appear in both instance and JSON but with different shapes. + */ + const TRANSFORMED_FIELDS = new Set(['error']) + + // Helper: create a fully-populated instance of each event class + function createEventInstances(): Array<{ name: string; event: { toJSON(): Record } }> { + const agent = new Agent() + const message = new Message({ role: 'assistant', content: [new TextBlock('test')] }) + const toolUse = { name: 'test', toolUseId: 'id-1', input: {} } + const result = new ToolResultBlock({ toolUseId: 'id-1', status: 'success', content: [new TextBlock('ok')] }) + const tool = new FunctionTool({ name: 'test', description: 'Test', inputSchema: {}, callback: () => 'ok' }) + const error = new Error('test error') + const stopData = { message, stopReason: 'endTurn' as const } + const streamEvent = { + type: 'modelContentBlockDeltaEvent' as const, + delta: { type: 'textDelta' as const, text: 'Hi' }, + } + const contentBlock = new TextBlock('test') + const toolStreamEvent = new ToolStreamEvent({ data: { progress: 50 } }) + const agentResult = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + return [ + { name: 'InitializedEvent', event: new InitializedEvent({ agent }) }, + { name: 'BeforeInvocationEvent', event: new BeforeInvocationEvent({ agent, invocationState: {} }) }, + { name: 'AfterInvocationEvent', event: new AfterInvocationEvent({ agent, invocationState: {} }) }, + { + name: 'BeforeModelCallEvent', + event: new BeforeModelCallEvent({ + agent, + model: agent.model, + invocationState: {}, + projectedInputTokens: 100, + }), + }, + { + name: 'AfterModelCallEvent', + event: Object.assign( + new AfterModelCallEvent({ agent, model: agent.model, attemptCount: 1, stopData, error, invocationState: {} }), + { retry: true } + ), + }, + { name: 'MessageAddedEvent', event: new MessageAddedEvent({ agent, message, invocationState: {} }) }, + { + name: 'ModelStreamUpdateEvent', + event: new ModelStreamUpdateEvent({ agent, event: streamEvent, invocationState: {} }), + }, + { name: 'ContentBlockEvent', event: new ContentBlockEvent({ agent, contentBlock, invocationState: {} }) }, + { + name: 'ModelMessageEvent', + event: new ModelMessageEvent({ agent, message, stopReason: 'endTurn', invocationState: {} }), + }, + { name: 'ToolResultEvent', event: new ToolResultEvent({ agent, result, invocationState: {} }) }, + { + name: 'ToolStreamUpdateEvent', + event: new ToolStreamUpdateEvent({ agent, event: toolStreamEvent, invocationState: {} }), + }, + { name: 'AgentResultEvent', event: new AgentResultEvent({ agent, result: agentResult, invocationState: {} }) }, + { name: 'BeforeToolCallEvent', event: new BeforeToolCallEvent({ agent, toolUse, tool, invocationState: {} }) }, + { + name: 'AfterToolCallEvent', + event: Object.assign(new AfterToolCallEvent({ agent, toolUse, tool, result, error, invocationState: {} }), { + retry: true, + }), + }, + { name: 'BeforeToolsEvent', event: new BeforeToolsEvent({ agent, message, invocationState: {} }) }, + { name: 'AfterToolsEvent', event: new AfterToolsEvent({ agent, message, invocationState: {} }) }, + ] + } + + const eventInstances = createEventInstances() + + it.each(eventInstances)('$name: toJSON() includes all fields except known exclusions', ({ event }) => { + const instanceKeys = new Set(Object.keys(event)) + const jsonKeys = new Set(Object.keys(event.toJSON())) + + // Every instance key should either be in JSON output, in the exclusion set, or transformed + for (const key of instanceKeys) { + if (!jsonKeys.has(key) && !TRANSFORMED_FIELDS.has(key)) { + expect(EXCLUDED_FIELDS).toContain(key) + } + } + + // Every JSON key should come from the instance or be a known transformation + for (const key of jsonKeys) { + expect(instanceKeys.has(key) || TRANSFORMED_FIELDS.has(key)).toBe(true) + } + }) + + it.each(eventInstances)('$name: toJSON() never includes agent', ({ event }) => { + const json = event.toJSON() + expect(json).not.toHaveProperty('agent') + }) +}) diff --git a/strands-ts/src/hooks/__tests__/registry.test.ts b/strands-ts/src/hooks/__tests__/registry.test.ts new file mode 100644 index 0000000000..3291cd06a6 --- /dev/null +++ b/strands-ts/src/hooks/__tests__/registry.test.ts @@ -0,0 +1,464 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { HookRegistryImplementation } from '../registry.js' +import { AfterInvocationEvent, BeforeInvocationEvent, BeforeToolCallEvent } from '../events.js' +import { Agent } from '../../agent/agent.js' +import { InterruptError, InterruptState } from '../../interrupt.js' + +describe('HookRegistryImplementation', () => { + let registry: HookRegistryImplementation + let mockAgent: Agent + + const getInterruptState = (agent: Agent): InterruptState => + (agent as unknown as { _interruptState: InterruptState })._interruptState + + beforeEach(() => { + registry = new HookRegistryImplementation() + mockAgent = new Agent() + }) + + describe('addCallback', () => { + it('registers callback for event type', async () => { + const callback = vi.fn() + registry.addCallback(BeforeInvocationEvent, callback) + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callback).toHaveBeenCalledOnce() + }) + + it('registers multiple callbacks for same event type', async () => { + const callback1 = vi.fn() + const callback2 = vi.fn() + + registry.addCallback(BeforeInvocationEvent, callback1) + registry.addCallback(BeforeInvocationEvent, callback2) + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callback1).toHaveBeenCalledOnce() + expect(callback2).toHaveBeenCalledOnce() + }) + + it('registers callbacks for different event types separately', async () => { + const beforeCallback = vi.fn() + const afterCallback = vi.fn() + + registry.addCallback(BeforeInvocationEvent, beforeCallback) + registry.addCallback(AfterInvocationEvent, afterCallback) + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(beforeCallback).toHaveBeenCalledOnce() + expect(afterCallback).not.toHaveBeenCalled() + + await registry.invokeCallbacks(new AfterInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(afterCallback).toHaveBeenCalledOnce() + }) + }) + + describe('invokeCallbacks', () => { + it('calls registered callbacks in order', async () => { + const callOrder: number[] = [] + const callback1 = vi.fn(() => { + callOrder.push(1) + }) + const callback2 = vi.fn(() => { + callOrder.push(2) + }) + + registry.addCallback(BeforeInvocationEvent, callback1) + registry.addCallback(BeforeInvocationEvent, callback2) + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callOrder).toEqual([1, 2]) + }) + + it('reverses callback order for After events', async () => { + const callOrder: number[] = [] + const callback1 = vi.fn(() => { + callOrder.push(1) + }) + const callback2 = vi.fn(() => { + callOrder.push(2) + }) + + registry.addCallback(AfterInvocationEvent, callback1) + registry.addCallback(AfterInvocationEvent, callback2) + + await registry.invokeCallbacks(new AfterInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callOrder).toEqual([2, 1]) + }) + + it('awaits async callbacks', async () => { + let completed = false + const callback = vi.fn(async (): Promise => { + await new Promise((resolve) => globalThis.setTimeout(resolve, 10)) + completed = true + }) + + registry.addCallback(BeforeInvocationEvent, callback) + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(completed).toBe(true) + }) + + it('propagates callback errors', async () => { + const callback = vi.fn(() => { + throw new Error('Hook failed') + }) + + registry.addCallback(BeforeInvocationEvent, callback) + + await expect( + registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + ).rejects.toThrow('Hook failed') + }) + + it('stops execution on first non-interrupt error', async () => { + const callback1 = vi.fn(() => { + throw new Error('First callback failed') + }) + const callback2 = vi.fn() + + registry.addCallback(BeforeInvocationEvent, callback1) + registry.addCallback(BeforeInvocationEvent, callback2) + + await expect( + registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + ).rejects.toThrow('First callback failed') + + expect(callback2).not.toHaveBeenCalled() + }) + + it('handles mixed sync and async callbacks', async () => { + const callOrder: string[] = [] + const syncCallback = vi.fn(() => { + callOrder.push('sync') + }) + const asyncCallback = vi.fn(async (): Promise => { + await new Promise((resolve) => globalThis.globalThis.setTimeout(resolve, 10)) + callOrder.push('async') + }) + + registry.addCallback(BeforeInvocationEvent, syncCallback) + registry.addCallback(BeforeInvocationEvent, asyncCallback) + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callOrder).toEqual(['sync', 'async']) + }) + + it('returns the event after invocation', async () => { + const event = new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} }) + const result = await registry.invokeCallbacks(event) + expect(result).toBe(event) + }) + }) + + describe('ordering', () => { + it('lower order runs first', async () => { + const callOrder: number[] = [] + registry.addCallback(BeforeInvocationEvent, () => { + callOrder.push(0) + }) + registry.addCallback( + BeforeInvocationEvent, + () => { + callOrder.push(100) + }, + { order: 100 } + ) + registry.addCallback( + BeforeInvocationEvent, + () => { + callOrder.push(-100) + }, + { order: -100 } + ) + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callOrder).toEqual([-100, 0, 100]) + }) + + it('same order preserves registration order', async () => { + const callOrder: string[] = [] + registry.addCallback( + BeforeInvocationEvent, + () => { + callOrder.push('first') + }, + { order: 10 } + ) + registry.addCallback( + BeforeInvocationEvent, + () => { + callOrder.push('second') + }, + { order: 10 } + ) + registry.addCallback( + BeforeInvocationEvent, + () => { + callOrder.push('third') + }, + { order: 10 } + ) + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callOrder).toEqual(['first', 'second', 'third']) + }) + + it('negative order runs before default', async () => { + const callOrder: string[] = [] + registry.addCallback(BeforeInvocationEvent, () => { + callOrder.push('default') + }) + registry.addCallback( + BeforeInvocationEvent, + () => { + callOrder.push('early') + }, + { order: -100 } + ) + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callOrder).toEqual(['early', 'default']) + }) + + it('After events: lower order still runs first across groups', async () => { + const callOrder: string[] = [] + registry.addCallback( + AfterInvocationEvent, + () => { + callOrder.push('early') + }, + { order: -100 } + ) + registry.addCallback( + AfterInvocationEvent, + () => { + callOrder.push('late') + }, + { order: 100 } + ) + + await registry.invokeCallbacks(new AfterInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callOrder).toEqual(['early', 'late']) + }) + }) + + describe('addCallback cleanup function', () => { + it('returns cleanup function that removes the callback', async () => { + const callback = vi.fn() + + const cleanup = registry.addCallback(BeforeInvocationEvent, callback) + cleanup() + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callback).not.toHaveBeenCalled() + }) + + it('cleanup function is idempotent', async () => { + const callback = vi.fn() + + const cleanup = registry.addCallback(BeforeInvocationEvent, callback) + cleanup() + cleanup() + cleanup() + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callback).not.toHaveBeenCalled() + }) + + it('cleanup function does not affect other callbacks', async () => { + const callback1 = vi.fn() + const callback2 = vi.fn() + + const cleanup1 = registry.addCallback(BeforeInvocationEvent, callback1) + registry.addCallback(BeforeInvocationEvent, callback2) + cleanup1() + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callback1).not.toHaveBeenCalled() + expect(callback2).toHaveBeenCalledOnce() + }) + + it('allows callback to be re-registered after cleanup', async () => { + const callback = vi.fn() + + const cleanup = registry.addCallback(BeforeInvocationEvent, callback) + cleanup() + + registry.addCallback(BeforeInvocationEvent, callback) + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callback).toHaveBeenCalledTimes(1) + }) + + it('cleanup from one registration does not affect independent registration of same function', async () => { + const callback1 = vi.fn() + const callback2 = vi.fn() + + registry.addCallback(BeforeInvocationEvent, callback1) + const cleanup2 = registry.addCallback(BeforeInvocationEvent, callback2) + cleanup2() + + await registry.invokeCallbacks(new BeforeInvocationEvent({ agent: mockAgent, invocationState: {} })) + + expect(callback1).toHaveBeenCalledOnce() + expect(callback2).not.toHaveBeenCalled() + }) + }) + + describe('InterruptError collection', () => { + const createEvent = () => + new BeforeToolCallEvent({ + agent: mockAgent, + toolUse: { name: 'test', toolUseId: 'tool-1', input: {} }, + tool: undefined, + invocationState: {}, + }) + + it('collects InterruptErrors from multiple callbacks and invokes all of them', async () => { + const event = createEvent() + + const callback1 = vi.fn(() => { + event.interrupt({ name: 'interrupt_a', reason: 'Reason A' }) + }) + const callback2 = vi.fn(() => { + event.interrupt({ name: 'interrupt_b', reason: 'Reason B' }) + }) + + registry.addCallback(BeforeToolCallEvent, callback1) + registry.addCallback(BeforeToolCallEvent, callback2) + + await expect(registry.invokeCallbacks(event)).rejects.toThrow(InterruptError) + + expect(callback1).toHaveBeenCalledOnce() + expect(callback2).toHaveBeenCalledOnce() + + const state = getInterruptState(mockAgent) + expect(Object.keys(state.interrupts).length).toBe(2) + expect( + state + .getInterruptsList() + .map((i) => i.name) + .sort() + ).toEqual(['interrupt_a', 'interrupt_b']) + }) + + it('throws InterruptError with all collected interrupts after all callbacks run', async () => { + const event = createEvent() + + registry.addCallback(BeforeToolCallEvent, () => { + event.interrupt({ name: 'first', reason: 'First' }) + }) + registry.addCallback(BeforeToolCallEvent, () => { + event.interrupt({ name: 'second', reason: 'Second' }) + }) + + try { + await registry.invokeCallbacks(event) + expect.unreachable('should have thrown') + } catch (error) { + expect(error).toBeInstanceOf(InterruptError) + const ie = error as InterruptError + expect(ie.interrupts).toHaveLength(2) + expect(ie.interrupts.map((i) => i.name)).toEqual(['first', 'second']) + expect(ie.message).toBe('2 interrupts raised: first, second') + } + }) + + it('stops on non-interrupt error even when interrupts were collected', async () => { + const event = createEvent() + const callback3 = vi.fn() + + registry.addCallback(BeforeToolCallEvent, () => { + event.interrupt({ name: 'interrupt_a', reason: 'Reason A' }) + }) + registry.addCallback(BeforeToolCallEvent, () => { + throw new Error('Non-interrupt failure') + }) + registry.addCallback(BeforeToolCallEvent, callback3) + + await expect(registry.invokeCallbacks(event)).rejects.toThrow('Non-interrupt failure') + expect(callback3).not.toHaveBeenCalled() + }) + + it('runs all callbacks when only some raise interrupts', async () => { + const event = createEvent() + const callOrder: string[] = [] + + registry.addCallback(BeforeToolCallEvent, () => { + callOrder.push('first') + event.interrupt({ name: 'interrupt_a', reason: 'Reason A' }) + }) + registry.addCallback(BeforeToolCallEvent, () => { + callOrder.push('second-no-interrupt') + }) + registry.addCallback(BeforeToolCallEvent, () => { + callOrder.push('third') + event.interrupt({ name: 'interrupt_b', reason: 'Reason B' }) + }) + + await expect(registry.invokeCallbacks(event)).rejects.toThrow(InterruptError) + expect(callOrder).toEqual(['first', 'second-no-interrupt', 'third']) + + const state = getInterruptState(mockAgent) + expect(Object.keys(state.interrupts).length).toBe(2) + expect( + state + .getInterruptsList() + .map((i) => i.name) + .sort() + ).toEqual(['interrupt_a', 'interrupt_b']) + }) + + it('throws when two callbacks use the same interrupt name', async () => { + const event = createEvent() + + registry.addCallback(BeforeToolCallEvent, () => { + event.interrupt({ name: 'confirm', reason: 'First' }) + }) + registry.addCallback(BeforeToolCallEvent, () => { + event.interrupt({ name: 'confirm', reason: 'Second' }) + }) + + await expect(registry.invokeCallbacks(event)).rejects.toThrow( + 'interrupt_names= | duplicate interrupt names' + ) + }) + + it('reports all duplicate interrupt names in error', async () => { + const event = createEvent() + + registry.addCallback(BeforeToolCallEvent, () => { + event.interrupt({ name: 'alpha' }) + }) + registry.addCallback(BeforeToolCallEvent, () => { + event.interrupt({ name: 'alpha' }) + }) + registry.addCallback(BeforeToolCallEvent, () => { + event.interrupt({ name: 'beta' }) + }) + registry.addCallback(BeforeToolCallEvent, () => { + event.interrupt({ name: 'beta' }) + }) + + await expect(registry.invokeCallbacks(event)).rejects.toThrow( + 'interrupt_names= | duplicate interrupt names' + ) + }) + }) +}) diff --git a/strands-ts/src/hooks/events.ts b/strands-ts/src/hooks/events.ts new file mode 100644 index 0000000000..d682a91d50 --- /dev/null +++ b/strands-ts/src/hooks/events.ts @@ -0,0 +1,817 @@ +import type { LocalAgent, AgentResult, InvocationState, InvokeArgs } from '../types/agent.js' +import type { ContentBlock, Message, StopReason, ToolResultBlock } from '../types/messages.js' +import { type Tool, ToolStreamEvent } from '../tools/tool.js' +import type { JSONValue } from '../types/json.js' +import type { ModelStreamEvent } from '../models/streaming.js' +import type { Model } from '../models/model.js' +import { interruptFromAgent, type Interrupt, type Interruptible } from '../interrupt.js' +import type { InterruptParams } from '../types/interrupt.js' + +/** + * Agent hook events. + * + * All events extend {@link StreamEvent} with a `readonly type` discriminator + * (camelCase of the class name) for switch-based narrowing. Constructor takes + * a single data-object parameter. Most properties are readonly — writable fields + * are the hook-driven control/data fields documented per event + * (e.g. `cancel`, `retry`, `selectedTool`, `resume`, and mutable `toolUse` / `result`). + * + * All current events extend {@link HookableEvent} (which itself extends {@link StreamEvent}), + * making them both streamable and subscribable via hook callbacks. {@link StreamEvent} exists + * as the base class for potential future events that should be stream-only without hookability. + * + * ## Event categories + * + * **Lifecycle events** — Before/After pairs that bracket agent operations. + * - Naming: `BeforeEvent` / `AfterEvent` + * - `After*` events override `_shouldReverseCallbacks()` → `true` for cleanup ordering. + * - Examples: {@link BeforeInvocationEvent}/{@link AfterInvocationEvent}, + * {@link BeforeModelCallEvent}/{@link AfterModelCallEvent}, + * {@link BeforeToolsEvent}/{@link AfterToolsEvent}, + * {@link BeforeToolCallEvent}/{@link AfterToolCallEvent} + * + * **State-change events** — Signal that agent state was mutated. + * - Naming: `Event` + * - Examples: {@link InitializedEvent}, {@link MessageAddedEvent} + * + * **Data events** — Wrap data objects produced during agent execution. + * Two sub-categories: + * + * *Update events* — wrap transient streaming data from lower layers. + * - Naming: `StreamUpdateEvent`, payload field: `.event` + * - Examples: {@link ModelStreamUpdateEvent}, {@link ToolStreamUpdateEvent} + * + * *Completion events* — wrap finished data after processing completes. + * - Naming: descriptive `Event`, payload field matches data type + * (`.result` for results, `.message` for messages, `.contentBlock` for content blocks). + * - Examples: {@link ContentBlockEvent}, {@link ModelMessageEvent}, + * {@link ToolResultEvent}, {@link AgentResultEvent} + * + * ## Field naming conventions + * + * | Field | Usage | + * |--------------------|--------------------------------------------------| + * | `agent` | `LocalAgent` reference on all agent-loop events | + * | `invocationState` | Per-invocation state — see below | + * | `.event` | Inner event in update wrappers | + * | `.result` | Finished result object | + * | `.message` | Message object | + * | `.contentBlock` | Content block object | + * + * ## `invocationState` on events + * + * Every hookable event that fires **during** an invocation carries + * {@link InvocationState} — the per-invocation mutable bag shared across hooks + * and tools. This lets any callback (lifecycle, data, streaming, completion) + * correlate back to the caller's request context (`userId`, `traceId`, etc.) + * without closure workarounds. + * + * The only events without `invocationState` are the ones that fire **outside** + * any invocation scope: {@link InitializedEvent} and `MultiAgentInitializedEvent`, + * both of which fire at construction. + * + * New events should follow the same rule: carry `invocationState` unless the + * event fires before any invocation exists. + */ + +/** + * Base class for all events yielded by `agent.stream()`. + * Carries no hookability — subclasses that should be hookable extend {@link HookableEvent} instead. + */ +export abstract class StreamEvent {} + +/** + * Base class for events that can be subscribed to via the hook system. + * Only events extending this class are dispatched to {@link HookRegistry} callbacks. + * All current events extend this class. {@link StreamEvent} exists as the base for + * potential future stream-only events that should not be hookable. + */ +export abstract class HookableEvent extends StreamEvent { + /** + * @internal + * Check if callbacks should be reversed for this event. + * Used by HookRegistry for callback ordering. + */ + _shouldReverseCallbacks(): boolean { + return false + } +} + +/** + * Mutable tool-use descriptor carried on tool-call hook events. + * Matches the shape of the tool use block the model emitted; hooks on + * {@link BeforeToolCallEvent} may mutate its fields (or reassign the object) + * to rewrite the input, id, or tool name before the tool executes. + */ +export interface ToolUseData { + name: string + toolUseId: string + input: JSONValue +} + +/** + * Event triggered when an agent has finished initialization. + * Fired after the agent has been fully constructed and all built-in components have been initialized. + */ +export class InitializedEvent extends HookableEvent { + readonly type = 'initializedEvent' as const + readonly agent: LocalAgent + + constructor(data: { agent: LocalAgent }) { + super() + this.agent = data.agent + } + + /** + * Serializes for wire transport, excluding the agent reference. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type } + } +} + +/** + * Event triggered at the beginning of a new agent request. + * Fired before any model inference or tool execution occurs. + */ +export class BeforeInvocationEvent extends HookableEvent { + readonly type = 'beforeInvocationEvent' as const + readonly agent: LocalAgent + readonly invocationState: InvocationState + + /** + * Set by hook callbacks to cancel this invocation. + * When set to `true`, a default cancel message is used. + * When set to a string, that string is used as the assistant response message. + */ + cancel: boolean | string = false + + constructor(data: { agent: LocalAgent; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.invocationState = data.invocationState + } + + /** + * Serializes for wire transport, excluding the agent reference and invocationState. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type } + } +} + +/** + * Event triggered at the end of an agent request. + * Fired after all processing completes, regardless of success or error. + * Uses reverse callback ordering for proper cleanup semantics. + */ +export class AfterInvocationEvent extends HookableEvent { + readonly type = 'afterInvocationEvent' as const + readonly agent: LocalAgent + readonly invocationState: InvocationState + + /** + * Set by hook callbacks to trigger a follow-up agent invocation with new input. + * When set, after this event's callbacks complete the agent re-enters its loop + * with these args as new input, under the same invocation lock. A fresh + * {@link BeforeInvocationEvent}/{@link AfterInvocationEvent} pair fires for the + * resumed run. Ignored if the invocation ended with an error. + * + * If multiple callbacks set `resume`, the last callback to run wins. + */ + resume: InvokeArgs | undefined = undefined + + constructor(data: { agent: LocalAgent; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.invocationState = data.invocationState + } + + override _shouldReverseCallbacks(): boolean { + return true + } + + /** + * Serializes for wire transport, excluding the agent reference, invocationState, + * and mutable resume field. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type } + } +} + +/** + * Event triggered when the framework adds a message to the conversation history. + * Fired for user input, assistant responses, and tool-result messages added + * during agent execution. Does not fire for messages preloaded via + * `AgentConfig.messages` or messages manually pushed to `agent.messages`. + */ +export class MessageAddedEvent extends HookableEvent { + readonly type = 'messageAddedEvent' as const + readonly agent: LocalAgent + readonly message: Message + readonly invocationState: InvocationState + + constructor(data: { agent: LocalAgent; message: Message; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.message = data.message + this.invocationState = data.invocationState + } + + /** + * Serializes for wire transport, excluding the agent reference and invocationState. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, message: this.message } + } +} + +/** + * Event triggered just before a tool is executed. + * Fired after tool lookup but before execution begins. + * + * Hook callbacks can: + * - Set {@link cancel} to prevent the tool from executing. + * - Set {@link selectedTool} to execute a different tool in place of the registry's match. + * - Mutate {@link toolUse} to rewrite the tool input, id, or name before execution. + * If `name` is changed and `selectedTool` is not set, the tool is re-resolved from + * the registry under the new name. + */ +export class BeforeToolCallEvent extends HookableEvent implements Interruptible { + readonly type = 'beforeToolCallEvent' as const + readonly agent: LocalAgent + toolUse: ToolUseData + readonly tool: Tool | undefined + readonly invocationState: InvocationState + + /** + * Set by hook callbacks to cancel this tool call. + * When set to `true`, a default cancel message is used. + * When set to a string, that string is used as the tool result error message. + */ + cancel: boolean | string = false + + /** + * Set by hook callbacks to execute a replacement tool instead of {@link tool}. + * When undefined, the tool looked up from the registry (or re-resolved from a + * mutated `toolUse.name`) is used. + * + * If multiple callbacks set `selectedTool`, the last callback to run wins. + * Callbacks run in registration order for this event, so the last-registered + * callback's value is the one used. + */ + selectedTool: Tool | undefined = undefined + + constructor(data: { + agent: LocalAgent + toolUse: ToolUseData + tool: Tool | undefined + invocationState: InvocationState + }) { + super() + this.agent = data.agent + this.toolUse = data.toolUse + this.tool = data.tool + this.invocationState = data.invocationState + } + + /** + * Raises an interrupt for human-in-the-loop workflows. + * If a response is available (from a previous resume), returns it immediately. + * Otherwise, throws an InterruptError to halt agent execution. + * + * @param params - Interrupt parameters including name and optional reason + * @returns The user's response when resuming from an interrupt + */ + interrupt(params: InterruptParams): T { + return interruptFromAgent( + this.agent, + `hook:beforeToolCall:${this.toolUse.toolUseId}:${params.name}`, + params, + 'hook' + ) + } + + /** + * Serializes for wire transport, excluding the agent reference, tool instance, + * invocationState, and mutable cancel / selectedTool fields. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, toolUse: this.toolUse } + } +} + +/** + * Event triggered after a tool execution completes. + * Fired after tool execution finishes, whether successful or failed. + * Uses reverse callback ordering for proper cleanup semantics. + * + * Hook callbacks can mutate {@link result} to rewrite the tool result before it + * propagates to the model (e.g. to redact or truncate output). + */ +export class AfterToolCallEvent extends HookableEvent { + readonly type = 'afterToolCallEvent' as const + readonly agent: LocalAgent + readonly toolUse: ToolUseData + readonly tool: Tool | undefined + + /** + * The tool result. Can be replaced by hook callbacks to transform the result + * before it enters the conversation history. + */ + result: ToolResultBlock + + readonly error?: Error + readonly invocationState: InvocationState + + /** + * Optional flag that can be set by hook callbacks to request a retry of the tool call. + * When set to true, the agent will re-execute the tool. + */ + retry?: boolean + + constructor(data: { + agent: LocalAgent + toolUse: ToolUseData + tool: Tool | undefined + result: ToolResultBlock + invocationState: InvocationState + error?: Error + }) { + super() + this.agent = data.agent + this.toolUse = data.toolUse + this.tool = data.tool + this.result = data.result + this.invocationState = data.invocationState + if (data.error !== undefined) { + this.error = data.error + } + } + + override _shouldReverseCallbacks(): boolean { + return true + } + + /** + * Serializes for wire transport, excluding the agent reference, tool instance, invocationState, and mutable retry flag. + * Converts Error to an extensible object for safe wire serialization. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick & { error?: { message?: string } } { + return { + type: this.type, + toolUse: this.toolUse, + result: this.result, + ...(this.error !== undefined && { error: { message: this.error.message } }), + } + } +} + +/** + * Event triggered just before the model is invoked. + * Fired before sending messages to the model for inference. + */ +export class BeforeModelCallEvent extends HookableEvent { + readonly type = 'beforeModelCallEvent' as const + readonly agent: LocalAgent + readonly model: Model + readonly invocationState: InvocationState + + /** + * Set by hook callbacks to cancel this model call. + * When set to `true`, a default cancel message is used. + * When set to a string, that string is used as the assistant response message. + */ + cancel: boolean | string = false + + /** + * Projected input token count for the upcoming model call. + * Computed by the agent loop from message metadata and token estimation. + * Available for hooks and plugins (e.g. conversation managers) to make + * proactive decisions about context management. + */ + readonly projectedInputTokens?: number + + constructor(data: { + agent: LocalAgent + model: Model + invocationState: InvocationState + projectedInputTokens?: number + }) { + super() + this.agent = data.agent + this.model = data.model + this.invocationState = data.invocationState + if (data.projectedInputTokens !== undefined) { + this.projectedInputTokens = data.projectedInputTokens + } + } + + /** + * Serializes for wire transport, excluding the agent reference and invocationState. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { + type: this.type, + ...(this.projectedInputTokens !== undefined && { projectedInputTokens: this.projectedInputTokens }), + } + } +} + +/** + * Redaction information when guardrails block content. + */ +export interface Redaction { + /** + * The text to replace the user message with. + * When present, indicates the last user message should be redacted with this text. + */ + userMessage: string +} + +/** + * Response from a model invocation containing the message and stop reason. + */ +export interface ModelStopData { + /** + * The message returned by the model. + */ + readonly message: Message + /** + * The reason the model stopped generating. + */ + readonly stopReason: StopReason + /** + * Optional redaction info when guardrails blocked input. + * When present, indicates the last user message was redacted. + * The redacted message is available in `agent.messages` (last message). + */ + readonly redaction?: Redaction +} + +/** + * Event triggered after the model invocation completes. + * Fired after the model finishes generating a response, whether successful or failed. + * Uses reverse callback ordering for proper cleanup semantics. + * + * Note: stopData may be undefined if an error occurs before the model completes. + */ +export class AfterModelCallEvent extends HookableEvent { + readonly type = 'afterModelCallEvent' as const + readonly agent: LocalAgent + readonly model: Model + readonly stopData?: ModelStopData + readonly error?: Error + readonly invocationState: InvocationState + + /** + * 1-indexed count of model attempts for this turn, including the attempt + * that just completed (or failed). The first call in a turn is `1`; each + * subsequent retry increments by one. + * + * Retry strategies may rely on `attemptCount === 1` to mark the start of a + * new retry sequence (e.g. to clear per-turn state carried over from a + * previous turn). The agent loop guarantees this marker on every fresh turn. + */ + readonly attemptCount: number + + /** + * Optional flag that can be set by hook callbacks to request a retry of the model call. + * When set to true, the agent will retry the model invocation. + */ + retry?: boolean + + constructor(data: { + agent: LocalAgent + model: Model + invocationState: InvocationState + attemptCount: number + stopData?: ModelStopData + error?: Error + }) { + super() + this.agent = data.agent + this.model = data.model + this.invocationState = data.invocationState + this.attemptCount = data.attemptCount + if (data.stopData !== undefined) { + this.stopData = data.stopData + } + if (data.error !== undefined) { + this.error = data.error + } + } + + override _shouldReverseCallbacks(): boolean { + return true + } + + /** + * Serializes for wire transport, excluding the agent reference, invocationState, and mutable retry flag. + * Converts Error to an extensible object for safe wire serialization. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick & { error?: { message?: string } } { + return { + type: this.type, + attemptCount: this.attemptCount, + ...(this.stopData !== undefined && { stopData: this.stopData }), + ...(this.error !== undefined && { error: { message: this.error.message } }), + } + } +} + +/** + * Event triggered for each streaming event from the model. + * Wraps a {@link ModelStreamEvent} (transient streaming delta) during model inference. + * Completed content blocks are handled separately by {@link ContentBlockEvent} + * because they represent different granularities: partial deltas vs fully assembled results. + */ +export class ModelStreamUpdateEvent extends HookableEvent { + readonly type = 'modelStreamUpdateEvent' as const + readonly agent: LocalAgent + readonly event: ModelStreamEvent + readonly invocationState: InvocationState + + constructor(data: { agent: LocalAgent; event: ModelStreamEvent; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.event = data.event + this.invocationState = data.invocationState + } + + /** + * Serializes for wire transport, excluding the agent reference and invocationState. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, event: this.event } + } +} + +/** + * Event triggered when a content block completes during model inference. + * Wraps completed content blocks (TextBlock, ToolUseBlock, ReasoningBlock) from model streaming. + * This is intentionally separate from {@link ModelStreamUpdateEvent}. The model's + * `streamAggregated()` yields two kinds of output: {@link ModelStreamEvent} (transient + * streaming deltas — partial data arriving while the model generates) and + * {@link ContentBlock} (fully assembled results after all deltas accumulate). + * These represent different granularities with different semantics, so they are + * wrapped in distinct event classes rather than combined into a single event. + */ +export class ContentBlockEvent extends HookableEvent { + readonly type = 'contentBlockEvent' as const + readonly agent: LocalAgent + readonly contentBlock: ContentBlock + readonly invocationState: InvocationState + + constructor(data: { agent: LocalAgent; contentBlock: ContentBlock; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.contentBlock = data.contentBlock + this.invocationState = data.invocationState + } + + /** + * Serializes for wire transport, excluding the agent reference and invocationState. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, contentBlock: this.contentBlock } + } +} + +/** + * Event triggered when the model completes a full message. + * Wraps the assembled message and stop reason after model streaming finishes. + */ +export class ModelMessageEvent extends HookableEvent { + readonly type = 'modelMessageEvent' as const + readonly agent: LocalAgent + readonly message: Message + readonly stopReason: StopReason + readonly invocationState: InvocationState + + constructor(data: { agent: LocalAgent; message: Message; stopReason: StopReason; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.message = data.message + this.stopReason = data.stopReason + this.invocationState = data.invocationState + } + + /** + * Serializes for wire transport, excluding the agent reference and invocationState. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, message: this.message, stopReason: this.stopReason } + } +} + +/** + * Event triggered when a tool execution completes. + * Wraps the tool result block after a tool finishes execution. + */ +export class ToolResultEvent extends HookableEvent { + readonly type = 'toolResultEvent' as const + readonly agent: LocalAgent + readonly result: ToolResultBlock + readonly invocationState: InvocationState + + constructor(data: { agent: LocalAgent; result: ToolResultBlock; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.result = data.result + this.invocationState = data.invocationState + } + + /** + * Serializes for wire transport, excluding the agent reference and invocationState. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, result: this.result } + } +} + +/** + * Event triggered for each streaming progress event from a tool during execution. + * Wraps a {@link ToolStreamEvent} with agent context, keeping the tool authoring + * interface unchanged — tools construct `ToolStreamEvent` without knowledge of agents + * or hooks, and the agent layer wraps them at the boundary. + * + * Consistent with {@link ModelStreamUpdateEvent} which wraps model streaming events + * the same way. + */ +export class ToolStreamUpdateEvent extends HookableEvent { + readonly type = 'toolStreamUpdateEvent' as const + readonly agent: LocalAgent + readonly event: ToolStreamEvent + readonly invocationState: InvocationState + + constructor(data: { agent: LocalAgent; event: ToolStreamEvent; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.event = data.event + this.invocationState = data.invocationState + } + + /** + * Serializes for wire transport, excluding the agent reference and invocationState. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, event: this.event } + } +} + +/** + * Event triggered as the final event in the agent stream. + * Wraps the agent result containing the stop reason and last message. + */ +export class AgentResultEvent extends HookableEvent { + readonly type = 'agentResultEvent' as const + readonly agent: LocalAgent + readonly result: AgentResult + readonly invocationState: InvocationState + + constructor(data: { agent: LocalAgent; result: AgentResult; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.result = data.result + this.invocationState = data.invocationState + } + + /** + * Serializes for wire transport, excluding the agent reference and invocationState. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, result: this.result } + } +} + +/** + * Event emitted when an interrupt is raised during agent execution. The `interrupt.source` + * field discriminates between tool-callback and hook-callback origins. One event fires + * per unanswered interrupt at the moment the agent stops to wait for responses. + */ +export class InterruptEvent extends HookableEvent { + readonly type = 'interruptEvent' as const + readonly agent: LocalAgent + readonly interrupt: Interrupt + readonly invocationState: InvocationState + + constructor(data: { agent: LocalAgent; interrupt: Interrupt; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.interrupt = data.interrupt + this.invocationState = data.invocationState + } + + /** Serializes for wire transport, excluding agent and invocationState. */ + toJSON(): Pick & { interrupt: ReturnType } { + return { type: this.type, interrupt: this.interrupt.toJSON() } + } +} + +/** + * Event triggered before executing tools. + * Fired when the model returns tool use blocks that need to be executed. + * Hook callbacks can set {@link cancel} to prevent all tools from executing. + */ +export class BeforeToolsEvent extends HookableEvent implements Interruptible { + readonly type = 'beforeToolsEvent' as const + readonly agent: LocalAgent + readonly message: Message + readonly invocationState: InvocationState + + /** + * Set by hook callbacks to cancel all tool calls. + * When set to `true`, a default cancel message is used. + * When set to a string, that string is used as the tool result error message. + */ + cancel: boolean | string = false + + constructor(data: { agent: LocalAgent; message: Message; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.message = data.message + this.invocationState = data.invocationState + } + + /** + * Raises an interrupt for human-in-the-loop workflows. + * If a response is available (from a previous resume), returns it immediately. + * Otherwise, throws an InterruptError to halt agent execution. + * + * @param params - Interrupt parameters including name and optional reason + * @returns The user's response when resuming from an interrupt + */ + interrupt(params: InterruptParams): T { + return interruptFromAgent(this.agent, `hook:beforeTools:${params.name}`, params, 'hook') + } + + /** + * Serializes for wire transport, excluding the agent reference, invocationState, and mutable cancel flag. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, message: this.message } + } +} + +/** + * Event triggered after all tools complete execution. + * Fired after tool results are collected and ready to be added to conversation. + * Uses reverse callback ordering for proper cleanup semantics. + */ +export class AfterToolsEvent extends HookableEvent { + readonly type = 'afterToolsEvent' as const + readonly agent: LocalAgent + readonly message: Message + readonly invocationState: InvocationState + + /** + * When set to `true`, the agent loop halts after this tool batch completes + * without calling the model again and a default message + * (`"Turn ended early by hook after tool execution"`) is appended as the + * final assistant message. When set to a string, that string is used instead + * of the default — the string becomes literal assistant content (a + * `TextBlock`), not a reason or label. Contrast with + * {@link BeforeToolCallEvent.cancel | cancel} fields on other events, where + * the string is a cancellation reason. + * + * In both cases `stopReason` on the returned `AgentResult` is `'endTurn'`. + */ + endTurn: boolean | string = false + + constructor(data: { agent: LocalAgent; message: Message; invocationState: InvocationState }) { + super() + this.agent = data.agent + this.message = data.message + this.invocationState = data.invocationState + } + + override _shouldReverseCallbacks(): boolean { + return true + } + + /** + * Serializes for wire transport, excluding the agent reference, invocationState, + * and mutable endTurn field. + * Called automatically by JSON.stringify(). + */ + toJSON(): Pick { + return { type: this.type, message: this.message } + } +} diff --git a/strands-ts/src/hooks/index.ts b/strands-ts/src/hooks/index.ts new file mode 100644 index 0000000000..b93496c05b --- /dev/null +++ b/strands-ts/src/hooks/index.ts @@ -0,0 +1,47 @@ +/** + * Hooks module for event-driven extensibility. + * + * This module has two concerns with distinct naming: + * + * - **Events** (`StreamEvent` and subclasses) — the data objects yielded by `agent.stream()`. + * Named `Stream*` because they are members of the agent stream. + * All current events extend {@link HookableEvent}, making them subscribable via hook callbacks. + * See {@link StreamEvent} and `events.ts` for the full taxonomy. + * + * - **Hook infrastructure** (`HookCallback`, `HookRegistry`, `HookCleanup`) — + * the subscription mechanism that lets callers register callbacks for {@link HookableEvent} types. + * Named `Hook*` because they describe the hooking/subscription pattern, not the events themselves. + */ + +// Event classes +export { + StreamEvent, + HookableEvent, + InitializedEvent, + BeforeInvocationEvent, + AfterInvocationEvent, + MessageAddedEvent, + BeforeToolCallEvent, + AfterToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, + ModelStreamUpdateEvent, + ContentBlockEvent, + ModelMessageEvent, + ToolResultEvent, + ToolStreamUpdateEvent, + AgentResultEvent, + InterruptEvent, + BeforeToolsEvent, + AfterToolsEvent, +} from './events.js' + +// Event types +export type { ModelStopData as ModelStopResponse, Redaction, ToolUseData } from './events.js' + +// Registry +export { HookRegistryImplementation as HookRegistry } from './registry.js' + +// Types +export type { HookCallback, HookableEventConstructor, HookCallbackOptions, HookCleanup } from './types.js' +export { HookOrder } from './types.js' diff --git a/strands-ts/src/hooks/registry.ts b/strands-ts/src/hooks/registry.ts new file mode 100644 index 0000000000..a47717d91b --- /dev/null +++ b/strands-ts/src/hooks/registry.ts @@ -0,0 +1,137 @@ +import type { HookableEvent } from './events.js' +import { HookOrder } from './types.js' +import type { HookCallback, HookableEventConstructor, HookCallbackOptions, HookCleanup } from './types.js' +import { InterruptError, Interrupt } from '../interrupt.js' + +/** + * Represents a registered callback entry. + */ +type CallbackEntry = { + callback: HookCallback + order: number +} + +/** + * Interface for hook registry operations. + * Enables registration of hook callbacks for event-driven extensibility. + */ +export interface HookRegistry { + /** + * Register a callback function for a specific event type. + * + * @param eventType - The event class constructor to register the callback for + * @param callback - The callback function to invoke when the event occurs + * @param options - Optional configuration including execution order + * @returns Cleanup function that removes the callback when invoked + */ + addCallback( + eventType: HookableEventConstructor, + callback: HookCallback, + options?: HookCallbackOptions + ): HookCleanup +} + +/** + * Implementation of the hook registry for managing hook callbacks. + * Maintains mappings between event types and callback functions. + */ +export class HookRegistryImplementation implements HookRegistry { + private readonly _callbacks: Map + + constructor() { + this._callbacks = new Map() + } + + /** {@inheritDoc HookRegistry.addCallback} */ + addCallback( + eventType: HookableEventConstructor, + callback: HookCallback, + options?: HookCallbackOptions + ): HookCleanup { + const entry: CallbackEntry = { + callback: callback as HookCallback, + order: options?.order ?? HookOrder.DEFAULT, + } + const callbacks = this._callbacks.get(eventType) ?? [] + // Insert in sorted position: lower order first, same order preserves registration order + const insertAt = callbacks.findIndex((e) => e.order > entry.order) + if (insertAt === -1) { + callbacks.push(entry) + } else { + callbacks.splice(insertAt, 0, entry) + } + this._callbacks.set(eventType, callbacks) + + return () => { + const callbacks = this._callbacks.get(eventType) + if (!callbacks) return + const index = callbacks.indexOf(entry) + if (index !== -1) { + callbacks.splice(index, 1) + } + } + } + + /** + * Invoke all registered callbacks for the given event. + * Awaits each callback, supporting both sync and async. + * + * InterruptErrors are collected across callbacks rather than immediately thrown, + * allowing all hooks to register their interrupts. Non-interrupt errors propagate immediately. + * + * @param event - The event to invoke callbacks for + * @returns The event after all callbacks have been invoked + * @throws InterruptError with all collected interrupts after all callbacks complete + */ + async invokeCallbacks(event: T): Promise { + const callbacks = this.getCallbacksFor(event) + const collectedInterrupts: Interrupt[] = [] + + for (const callback of callbacks) { + try { + await callback(event) + } catch (error) { + if (error instanceof InterruptError) { + collectedInterrupts.push(...error.interrupts) + } else { + throw error + } + } + } + + if (collectedInterrupts.length > 0) { + const seen = new Set() + const duplicates = new Set() + for (const interrupt of collectedInterrupts) { + if (seen.has(interrupt.name)) { + duplicates.add(interrupt.name) + } + seen.add(interrupt.name) + } + if (duplicates.size > 0) { + const names = [...duplicates].join(', ') + throw new Error(`interrupt_names=<${names}> | duplicate interrupt names`) + } + throw new InterruptError(collectedInterrupts) + } + + return event + } + + /** + * Get callbacks for a specific event in order. + * For After* events, reverses then re-sorts by order so that lower order + * still runs first, but same-order hooks run in reverse registration order. + * + * @param event - The event to get callbacks for + * @returns Array of callbacks for the event + */ + private getCallbacksFor(event: T): HookCallback[] { + const entries = this._callbacks.get(event.constructor as HookableEventConstructor) ?? [] + if (event._shouldReverseCallbacks()) { + const reversed = [...entries].reverse().sort((a, b) => a.order - b.order) + return reversed.map((entry) => entry.callback) as HookCallback[] + } + return entries.map((entry) => entry.callback) as HookCallback[] + } +} diff --git a/strands-ts/src/hooks/types.ts b/strands-ts/src/hooks/types.ts new file mode 100644 index 0000000000..f1e17efdb2 --- /dev/null +++ b/strands-ts/src/hooks/types.ts @@ -0,0 +1,54 @@ +import type { HookableEvent } from './events.js' + +/** + * Type for a constructor function that creates HookableEvent instances. + */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type HookableEventConstructor = new (...args: any[]) => T + +/** + * Type for callback functions that handle hookable events. + * Callbacks can be synchronous or asynchronous. + * + * @example + * ```typescript + * const callback: HookCallback = (event) => { + * console.log('Agent invocation started') + * } + * ``` + */ +export type HookCallback = (event: T) => void | Promise + +/** + * Options for registering a hook callback. + */ +export interface HookCallbackOptions { + order?: number +} + +/** + * Function that removes a previously registered hook callback. + * Safe to call multiple times (idempotent). + * No-op if the callback is no longer registered. + */ +export type HookCleanup = () => void + +/** + * Presets for hook execution order. Lower values run first. + * Any number is a valid order — these presets are not bounds, just convenient + * reference points. SDK_FIRST/SDK_LAST mark where the SDK's own hooks run, + * so you can position yours relative to them. + * + * @example + * ```typescript + * agent.addHook(BeforeToolCallEvent, callback, { order: HookOrder.SDK_FIRST }) // run with the SDK's earliest hooks + * agent.addHook(BeforeToolCallEvent, callback, { order: HookOrder.SDK_FIRST - 1 }) // run before the SDK's earliest hooks + * ``` + */ +export const HookOrder = { + SDK_FIRST: -100, + INTERVENTION_OUTPUT: -90, + DEFAULT: 0, + INTERVENTION_INPUT: 90, + SDK_LAST: 100, +} as const diff --git a/strands-ts/src/index.ts b/strands-ts/src/index.ts new file mode 100644 index 0000000000..52316b23d0 --- /dev/null +++ b/strands-ts/src/index.ts @@ -0,0 +1,313 @@ +/** + * Main entry point for the Strands Agents TypeScript SDK. + * + * This is the primary export module for the SDK, providing access to all + * public APIs and functionality. + */ + +// Agent class +export { Agent } from './agent/agent.js' + +// App state +export { StateStore } from './state-store.js' + +// Agent types +export { AgentResult } from './types/agent.js' +export type { AgentConfig, ToolList, ToolExecutorStrategy } from './agent/agent.js' +export type { AgentAsToolOptions } from './agent/agent-as-tool.js' +export type { ToolCaller, ToolCallerProxy, ToolHandle, DirectToolCallOptions } from './agent/tool-caller.js' +export type { InvocationState, InvokeArgs, InvokeOptions, LocalAgent } from './types/agent.js' +export type { LifecycleObserver } from './types/lifecycle-observer.js' + +// Snapshot types +export { SNAPSHOT_SCHEMA_VERSION } from './types/snapshot.js' +export type { Scope, Snapshot } from './types/snapshot.js' +export type { TakeSnapshotOptions, SnapshotField, SnapshotPreset } from './agent/snapshot.js' + +// Error types +// Note: CancelledError is intentionally not exported — it is an internal +// control-flow mechanism, never thrown to consumers. See its docstring in errors.ts. +export { + ModelError, + ContextWindowOverflowError, + MaxTokensError, + JsonValidationError, + ConcurrentInvocationError, + ModelThrottledError, + ToolValidationError, + StructuredOutputError, + ToolNotFoundError, +} from './errors.js' + +// Interrupt system +export type { Interrupt, InterruptSource } from './interrupt.js' +export type { InterruptParams, InterruptResponse, InterruptResponseContentData } from './types/interrupt.js' +export { InterruptResponseContent } from './types/interrupt.js' + +// JSON types +export type { JSONSchema, JSONValue } from './types/json.js' + +// Message types +export type { + Role, + StopReason, + TextBlockData, + ToolUseBlockData, + ToolResultBlockData, + ReasoningBlockData, + CachePointBlockData, + GuardContentBlockData, + GuardContentText, + GuardContentImage, + GuardQualifier, + GuardImageFormat, + GuardImageSource, + ContentBlock, + ContentBlockData, + MessageData, + SystemPrompt, + SystemPromptData, + SystemContentBlock, + ToolResultContent, +} from './types/messages.js' + +// Message classes +export { + TextBlock, + ToolUseBlock, + ToolResultBlock, + ReasoningBlock, + CachePointBlock, + GuardContentBlock, + Message, + JsonBlock, + contentBlockFromData, + toolResultContentFromData, +} from './types/messages.js' + +// Citation types +export type { + CitationsBlockData, + Citation, + CitationLocation, + CitationSourceContent, + CitationGeneratedContent, +} from './types/citations.js' + +// Citation class +export { CitationsBlock } from './types/citations.js' + +// Media classes +export { S3Location, ImageBlock, VideoBlock, DocumentBlock } from './types/media.js' + +// Media types +export type { + LocationData, + S3LocationData, + ImageFormat, + ImageSource, + ImageSourceData, + ImageBlockData, + VideoFormat, + VideoSource, + VideoSourceData, + VideoBlockData, + DocumentFormat, + DocumentSource, + DocumentSourceData, + DocumentBlockData, + DocumentContentBlock, + DocumentContentBlockData, +} from './types/media.js' + +// Tool types +export type { ToolSpec, ToolUse, ToolResultStatus, ToolChoice } from './tools/types.js' + +// Tool interface and related types +export type { InvokableTool, ToolContext, ToolStreamEventData, ToolStreamGenerator } from './tools/tool.js' + +// Tool base class and event classes +export { Tool, ToolStreamEvent } from './tools/tool.js' + +// FunctionTool implementation +export { FunctionTool } from './tools/function-tool.js' +export type { FunctionToolConfig, FunctionToolCallback } from './tools/function-tool.js' + +// ZodTool implementation +export { ZodTool } from './tools/zod-tool.js' +export type { ZodToolConfig } from './tools/zod-tool.js' + +// Tool factory function +export { tool } from './tools/tool-factory.js' + +// Streaming event types +export type { + Usage, + Metrics, + ModelMessageStartEventData, + ToolUseStart, + ContentBlockStart, + ModelContentBlockStartEventData, + TextDelta, + ToolUseInputDelta, + ReasoningContentDelta, + CitationsDelta, + ContentBlockDelta, + ModelContentBlockDeltaEventData, + ModelMessageStopEventData, + ModelMetadataEventData, + RedactInputContent, + RedactOutputContent, + ModelRedactionEventData, + ModelStreamEvent, +} from './models/streaming.js' + +// Streaming event classes (value exports for instanceof checks and custom model providers) +export { + isModelStreamEvent, + ModelMessageStartEvent, + ModelContentBlockStartEvent, + ModelContentBlockDeltaEvent, + ModelContentBlockStopEvent, + ModelMessageStopEvent, + ModelMetadataEvent, + ModelRedactionEvent, +} from './models/streaming.js' + +// Model provider types +export type { BaseModelConfig, CountTokensOptions, StreamOptions, CacheConfig } from './models/model.js' + +export { Model } from './models/model.js' + +// Bedrock model provider +export { BedrockModel as BedrockModel } from './models/bedrock.js' +export type { + BedrockModelConfig, + BedrockModelOptions, + BedrockGuardrailConfig, + BedrockGuardrailRedactionConfig, + BedrockCacheConfig, + BedrockCacheTTL, +} from './models/bedrock.js' + +// Agent streaming event types +export type { AgentStreamEvent } from './types/agent.js' + +// Hooks system +export { + HookRegistry, + HookOrder, + StreamEvent, + HookableEvent, + InitializedEvent, + BeforeInvocationEvent, + AfterInvocationEvent, + MessageAddedEvent, + BeforeToolCallEvent, + AfterToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, + BeforeToolsEvent, + AfterToolsEvent, + ContentBlockEvent, + ModelMessageEvent, + ToolResultEvent, + ToolStreamUpdateEvent, + AgentResultEvent, + InterruptEvent, + ModelStreamUpdateEvent, +} from './hooks/index.js' +export type { + HookCallback, + HookableEventConstructor, + HookCallbackOptions, + ModelStopResponse, + Redaction, + ToolUseData, +} from './hooks/index.js' + +// Plugin system +export type { Plugin } from './plugins/index.js' + +// Intervention system +export { InterventionHandler, InterventionActions } from './interventions/index.js' +export type { OnError } from './interventions/index.js' + +// Retry +export { + type BackoffContext, + type BackoffStrategy, + type JitterKind, + type ConstantBackoffOptions, + type LinearBackoffOptions, + type ExponentialBackoffOptions, + ConstantBackoff, + LinearBackoff, + ExponentialBackoff, + ModelRetryStrategy, + DefaultModelRetryStrategy, + type DefaultModelRetryStrategyOptions, + type RetryStrategy, + type RetryDecision, +} from './retry/index.js' + +// Conversation Manager +export { + ConversationManager, + type ProactiveCompressionConfig, + type ConversationManagerReduceOptions, + type ConversationManagerOptions, +} from './conversation-manager/conversation-manager.js' +export { NullConversationManager } from './conversation-manager/null-conversation-manager.js' +export { + SlidingWindowConversationManager, + type SlidingWindowConversationManagerConfig, +} from './conversation-manager/sliding-window-conversation-manager.js' +export { + SummarizingConversationManager, + type SummarizingConversationManagerConfig, +} from './conversation-manager/summarizing-conversation-manager.js' + +// Logging +export { configureLogging } from './logging/logger.js' +export type { Logger } from './logging/types.js' + +// MCP Client types and implementations +export { + type McpClientOptions, + type McpClientConfig, + type McpClientCredentials, + type McpTransport, + type McpCallToolOptions, + type TasksConfig, + type McpConnectionState, + McpClient, +} from './mcp.js' +export { type McpServerConfig } from './mcp-config.js' +export type { ElicitationCallback, ElicitationContext } from './types/elicitation.js' + +// Session management +export { SessionManager } from './session/session-manager.js' +export type { + SessionManagerConfig, + SaveLatestStrategy, + MultiAgentSaveLatestStrategy, +} from './session/session-manager.js' +export type { SnapshotManifest, SnapshotTriggerCallback, SnapshotTriggerParams } from './session/types.js' +export type { SessionStorage, SnapshotStorage, SnapshotLocation } from './session/storage.js' +export { FileStorage } from './session/file-storage.js' + +// Local Traces +export { AgentTrace } from './telemetry/tracer.js' + +// Local Metrics +export { AgentMetrics } from './telemetry/meter.js' + +// Sandbox +export { Sandbox, type ExecuteOptions } from './sandbox/base.js' +export { PosixShellSandbox } from './sandbox/posix-shell.js' +export type { StreamType, StreamChunk, FileInfo, OutputFile, ExecutionResult } from './sandbox/types.js' + +// Multi-agent orchestration +export { Graph } from './multiagent/index.js' +export { Swarm } from './multiagent/index.js' diff --git a/strands-ts/src/interrupt.ts b/strands-ts/src/interrupt.ts new file mode 100644 index 0000000000..116254abc3 --- /dev/null +++ b/strands-ts/src/interrupt.ts @@ -0,0 +1,454 @@ +/** + * Human-in-the-loop interrupt system for agent workflows. + * + * Interrupt Flow: + * 1. Hook or tool calls `event.interrupt()` or `context.interrupt()` + * 2. If resuming (response exists), the response is returned + * 3. Otherwise, agent execution halts with `stopReason: 'interrupt'` + * 4. User resumes by invoking agent with `interruptResponse` content blocks + * 5. On resume, `interrupt()` returns the user's response + */ + +import { InterruptResponseContent, type InterruptResponseContentData, type InterruptParams } from './types/interrupt.js' +import type { JSONValue } from './types/json.js' +import type { LocalAgent } from './types/agent.js' +import { Message, ToolResultBlock, type MessageData, type ToolResultBlockData } from './types/messages.js' + +/** + * Origin of an interrupt: + * - `'tool'` — raised by a tool callback via `ToolContext.interrupt()`. + * - `'hook'` — raised by an agent-level hook (e.g. `BeforeToolCallEvent.interrupt()`). + * - `'multiagent-hook'` — raised by a multi-agent hook (e.g. `BeforeNodeCallEvent.interrupt()`). + */ +export type InterruptSource = 'tool' | 'hook' | 'multiagent-hook' + +/** + * Represents an interrupt that can pause agent execution for human-in-the-loop workflows. + */ +export class Interrupt { + /** + * Unique identifier for this interrupt. + */ + readonly id: string + + /** + * User-defined name for the interrupt. + */ + readonly name: string + + /** + * User-provided reason for raising the interrupt. + */ + readonly reason?: JSONValue + + /** + * Human response provided when resuming the agent after an interrupt. + */ + response?: JSONValue + + /** + * Where this interrupt was raised from — a tool callback, an agent-level hook, or + * a multi-agent orchestrator hook. Always populated. When deserializing a snapshot + * produced by an older SDK that did not record this field, defaults to `'hook'`. + */ + readonly source: InterruptSource + + constructor(data: { id: string; name: string; reason?: JSONValue; response?: JSONValue; source?: InterruptSource }) { + this.id = data.id + this.name = data.name + if (data.reason !== undefined) { + this.reason = data.reason + } + if (data.response !== undefined) { + this.response = data.response + } + // Default for legacy snapshots that predate the `source` field; current code + // paths always supply a value explicitly. + this.source = data.source ?? 'hook' + } + + /** + * Serializes the interrupt to a JSON-compatible object. + */ + toJSON(): { id: string; name: string; reason?: JSONValue; response?: JSONValue; source: InterruptSource } { + return { + id: this.id, + name: this.name, + ...(this.reason !== undefined && { reason: this.reason }), + ...(this.response !== undefined && { response: this.response }), + source: this.source, + } + } + + /** + * Creates an Interrupt instance from a JSON object. + * + * @param data - JSON data to deserialize + * @returns Interrupt instance + */ + static fromJSON(data: { + id: string + name: string + reason?: JSONValue + response?: JSONValue + source?: InterruptSource + }): Interrupt { + return new Interrupt(data) + } +} + +/** + * Error thrown when human input is required to continue agent execution. + * Caught by the agent loop to trigger an interrupt stop. + */ +export class InterruptError extends Error { + /** + * The interrupts that caused this error. + */ + readonly interrupts: Interrupt[] + + constructor(interrupt: Interrupt | Interrupt[]) { + const all = Array.isArray(interrupt) ? interrupt : [interrupt] + const message = + all.length === 1 + ? `Interrupt raised: ${all[0]!.name}` + : `${all.length} interrupts raised: ${all.map((i) => i.name).join(', ')}` + super(message) + this.name = 'InterruptError' + this.interrupts = all + } +} + +/** + * Data format for serialized interrupt state. + */ +export interface InterruptStateData { + /** + * Map of interrupt IDs to interrupt data. + */ + interrupts: Record + + /** + * Resume responses that were provided when resuming from an interrupt. + */ + resumeResponses?: InterruptResponseContentData[] | undefined + + /** + * Whether the agent is in an interrupted state. + */ + activated: boolean + + /** + * Pending tool execution state for resume after interrupt. + */ + pendingToolExecution?: PendingToolExecution | undefined +} + +/** + * Pending tool execution state stored when an interrupt occurs mid-execution. + * Contains all data needed to resume tool execution without re-calling the model. + */ +export interface PendingToolExecution { + /** + * The assistant message containing tool use blocks, serialized as MessageData. + */ + assistantMessageData: MessageData + + /** + * Tool results that were completed before the interrupt. + * Maps toolUseId to serialized ToolResultBlock data. + */ + completedToolResults: Record +} + +/** + * Tracks the state of interrupt events raised during agent execution. + * + * Interrupt state is cleared after resuming. + */ +export class InterruptState implements InterruptStateData { + /** Record of interrupt IDs to Interrupt instances. */ + interrupts: Record + + /** Resume responses provided when resuming from an interrupt. */ + resumeResponses?: InterruptResponseContent[] | undefined + + /** Whether the agent is in an interrupted state. */ + activated: boolean + + /** Pending tool execution state for resume. */ + pendingToolExecution?: PendingToolExecution | undefined + + constructor() { + this.interrupts = {} + this.resumeResponses = undefined + this.activated = false + this.pendingToolExecution = undefined + } + + /** + * Gets the pending tool execution state with reconstructed Message and ToolResultBlock objects. + * Returns undefined if there is no pending execution. + */ + getPendingExecution(): { assistantMessage: Message; completedToolResults: Map } | undefined { + if (!this.pendingToolExecution) { + return undefined + } + + const assistantMessage = Message.fromMessageData(this.pendingToolExecution.assistantMessageData) + + const completedToolResults = new Map() + for (const [toolUseId, resultData] of Object.entries(this.pendingToolExecution.completedToolResults)) { + completedToolResults.set(toolUseId, ToolResultBlock.fromJSON(resultData)) + } + + return { assistantMessage, completedToolResults } + } + + /** + * Sets the pending tool execution state. + */ + setPendingToolExecution(pending: PendingToolExecution): void { + this.pendingToolExecution = pending + } + + /** + * Clears the pending tool execution state. + */ + clearPendingToolExecution(): void { + this.pendingToolExecution = undefined + } + + /** + * Returns the list of interrupts as an array. + */ + getInterruptsList(): Interrupt[] { + return Object.values(this.interrupts) + } + + /** + * Returns all interrupts that have no response (i.e., were raised but not yet answered). + */ + getUnansweredInterrupts(): Interrupt[] { + return Object.values(this.interrupts).filter((interrupt) => interrupt.response === undefined) + } + + /** + * Returns the first interrupt that has no response (i.e., was raised but not yet answered). + */ + getUnansweredInterrupt(): Interrupt | undefined { + for (const interrupt of Object.values(this.interrupts)) { + if (interrupt.response === undefined) { + return interrupt + } + } + return undefined + } + + /** + * Activates the interrupt state. + */ + activate(): void { + this.activated = true + } + + /** + * Deactivates the interrupt state and clears all interrupts and context. + */ + deactivate(): void { + this.interrupts = {} + this.resumeResponses = undefined + this.activated = false + this.pendingToolExecution = undefined + } + + /** + * Configures the interrupt state for resuming from an interrupt. + * Populates interrupt responses from the provided content blocks. + * + * @param responses - Array of interrupt response content blocks + * @throws Error if an interrupt ID is not found + */ + resume(responses: InterruptResponseContent[]): void { + if (!this.activated) { + return + } + + for (const content of responses) { + const interruptId = content.interruptResponse.interruptId + const response = content.interruptResponse.response + + const interrupt = this.interrupts[interruptId] + if (!interrupt) { + throw new Error(`interrupt_id=<${interruptId}> | no interrupt found`) + } + + interrupt.response = response + } + + this.resumeResponses = responses + } + + /** + * Gets or creates an interrupt with the given ID. + * If the interrupt already exists, returns it (potentially with a response). + * If a preemptive response is provided and the interrupt is new, the response + * is stored on the interrupt so it returns immediately without halting execution. + * + * @param id - Unique identifier for the interrupt + * @param name - User-defined name for the interrupt + * @param reason - Optional reason for the interrupt + * @param response - Optional preemptive response to skip the interrupt + * @param source - Where the interrupt was raised from (tool or hook callback) + * @returns The interrupt (may have a response if resuming or preemptive) + */ + getOrCreateInterrupt( + id: string, + name: string, + reason?: JSONValue, + response?: JSONValue, + source?: InterruptSource + ): Interrupt { + const existing = this.interrupts[id] + if (existing) { + return existing + } + + const interrupt = new Interrupt({ + id, + name, + ...(reason !== undefined && { reason }), + ...(response !== undefined && { response }), + ...(source !== undefined && { source }), + }) + this.interrupts[id] = interrupt + return interrupt + } + + /** + * Serializes the interrupt state to a JSON-compatible object. + */ + toJSON(): InterruptStateData { + const interrupts: Record = {} + for (const [id, interrupt] of Object.entries(this.interrupts)) { + interrupts[id] = interrupt.toJSON() + } + + return { + interrupts, + ...(this.resumeResponses && { resumeResponses: this.resumeResponses }), + activated: this.activated, + ...(this.pendingToolExecution && { pendingToolExecution: this.pendingToolExecution }), + } + } + + /** + * Creates an InterruptState instance from a JSON object. + * + * @param data - JSON data to deserialize + * @returns InterruptState instance + */ + static fromJSON(data: InterruptStateData): InterruptState { + const state = new InterruptState() + state.activated = data.activated + + for (const [id, interruptData] of Object.entries(data.interrupts)) { + state.interrupts[id] = Interrupt.fromJSON(interruptData) + } + + if (data.resumeResponses) { + state.resumeResponses = data.resumeResponses.map((r) => InterruptResponseContent.fromJSON(r)) + } + + if (data.pendingToolExecution) { + state.pendingToolExecution = data.pendingToolExecution + } + + return state + } +} + +/** + * Interface for objects that support human-in-the-loop interrupts. + * Implemented by hook events and tool contexts that can pause agent execution. + */ +export interface Interruptible { + interrupt(params: InterruptParams): T +} + +/** + * Shared interrupt logic that accesses the agent's interrupt state to register or resume an interrupt. + * + * @param agent - The agent whose interrupt state to access + * @param interruptId - Unique identifier for this interrupt instance + * @param params - Interrupt parameters including name and optional reason + * @param source - Where the interrupt was raised from (tool callback vs hook callback) + * @returns The user's response when resuming from an interrupt + * @throws InterruptError when no response is available (first invocation) + * + * @internal + */ +export function interruptFromAgent( + agent: LocalAgent, + interruptId: string, + params: InterruptParams, + source: InterruptSource +): T { + const interruptState = (agent as unknown as { _interruptState?: InterruptState })._interruptState + if (!interruptState) { + throw new Error('Interrupt state not available') + } + + const interrupt = interruptState.getOrCreateInterrupt( + interruptId, + params.name, + params.reason, + params.response, + source + ) + + if (interrupt.response !== undefined) { + return interrupt.response as T + } + + throw new InterruptError(interrupt) +} + +/** + * Interrupt-or-resume helper for multi-agent hooks where interrupts live on a per-node + * `Interrupt[]` list rather than on an agent's `InterruptState`. Mirrors the + * {@link interruptFromAgent} contract: returns the response if the interrupt already + * has one (resume path), otherwise records a new interrupt and throws `InterruptError`. + * + * @internal + */ +export function interruptFromMultiAgentNode( + interrupts: Interrupt[], + interruptId: string, + params: InterruptParams, + source: InterruptSource +): T { + const existing = interrupts.find((i) => i.id === interruptId) + if (existing?.response !== undefined) { + return existing.response as T + } + + const interrupt = + existing ?? + new Interrupt({ + id: interruptId, + name: params.name, + ...(params.reason !== undefined && { reason: params.reason }), + ...(params.response !== undefined && { response: params.response }), + source, + }) + + if (!existing) { + interrupts.push(interrupt) + if (interrupt.response !== undefined) { + return interrupt.response as T + } + } + + throw new InterruptError(interrupt) +} diff --git a/strands-ts/src/interventions/__tests__/handler.test.ts b/strands-ts/src/interventions/__tests__/handler.test.ts new file mode 100644 index 0000000000..9185069d20 --- /dev/null +++ b/strands-ts/src/interventions/__tests__/handler.test.ts @@ -0,0 +1,50 @@ +import { describe, expect, it } from 'vitest' +import { InterventionHandler } from '../handler.js' +import { Agent } from '../../agent/agent.js' +import { BeforeToolCallEvent, AfterModelCallEvent } from '../../hooks/events.js' +import type { InterventionAction } from '../actions.js' + +class NoOpHandler extends InterventionHandler { + readonly name = 'no-op' +} + +class ToolOnlyHandler extends InterventionHandler { + readonly name = 'tool-only' + + override beforeToolCall(): InterventionAction { + return { type: 'deny', reason: 'blocked' } + } +} + +describe('InterventionHandler', () => { + const agent = new Agent() + const toolUse = { name: 'test', toolUseId: 'id', input: {} } + + it('default methods return proceed', () => { + const handler = new NoOpHandler() + + expect( + handler.beforeToolCall(new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} })) + ).toEqual({ + type: 'proceed', + }) + expect( + handler.afterModelCall( + new AfterModelCallEvent({ agent, model: {} as never, invocationState: {}, attemptCount: 0 }) + ) + ).toEqual({ + type: 'proceed', + }) + }) + + it('override detection works via prototype comparison', () => { + const noOp = new NoOpHandler() + const toolOnly = new ToolOnlyHandler() + + expect(noOp.beforeToolCall).toBe(InterventionHandler.prototype.beforeToolCall) + expect(noOp.afterModelCall).toBe(InterventionHandler.prototype.afterModelCall) + + expect(toolOnly.beforeToolCall).not.toBe(InterventionHandler.prototype.beforeToolCall) + expect(toolOnly.afterModelCall).toBe(InterventionHandler.prototype.afterModelCall) + }) +}) diff --git a/strands-ts/src/interventions/__tests__/registry.test.ts b/strands-ts/src/interventions/__tests__/registry.test.ts new file mode 100644 index 0000000000..0aac879340 --- /dev/null +++ b/strands-ts/src/interventions/__tests__/registry.test.ts @@ -0,0 +1,873 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { InterventionRegistry } from '../registry.js' +import { InterventionHandler } from '../handler.js' +import { HookRegistryImplementation } from '../../hooks/registry.js' +import { Agent } from '../../agent/agent.js' +import { + BeforeInvocationEvent, + BeforeToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, +} from '../../hooks/events.js' +import { Message, TextBlock } from '../../types/messages.js' +import { deny } from '../actions.js' +import type { InterventionAction, Guide, Transform, Proceed } from '../actions.js' +import { Interrupt, InterruptState } from '../../interrupt.js' + +class DenyHandler extends InterventionHandler { + readonly name = 'deny-handler' + + override beforeToolCall(): InterventionAction { + return { type: 'deny', reason: 'not authorized' } + } +} + +class GuideHandler extends InterventionHandler { + readonly name = 'guide-handler' + + override beforeToolCall(): InterventionAction { + return { type: 'guide', feedback: 'add more context' } + } +} + +class ConfirmHandler extends InterventionHandler { + readonly name = 'confirm-handler' + + override beforeToolCall(): InterventionAction { + return { type: 'confirm', prompt: 'approve this action?' } + } +} + +class ProceedHandler extends InterventionHandler { + readonly name = 'proceed-handler' + + override beforeToolCall(): InterventionAction { + return { type: 'proceed', reason: 'all good' } + } +} + +class ThrowingHandler extends InterventionHandler { + readonly name = 'throwing-handler' + override readonly onError = 'throw' as const + + override beforeToolCall(): InterventionAction { + throw new Error('handler crashed') + } +} + +class ThrowingProceedHandler extends InterventionHandler { + readonly name = 'throwing-proceed' + override readonly onError = 'proceed' as const + + override beforeToolCall(): InterventionAction { + throw new Error('handler crashed') + } +} + +class ThrowingDenyHandler extends InterventionHandler { + readonly name = 'throwing-deny' + override readonly onError = 'deny' as const + + override beforeToolCall(): InterventionAction { + throw new Error('handler crashed') + } +} + +class AsyncDenyHandler extends InterventionHandler { + readonly name = 'async-deny' + + override async beforeToolCall(): Promise { + return { type: 'deny', reason: 'async denial' } + } +} + +class ModelGuideHandler extends InterventionHandler { + readonly name = 'model-guide' + + override afterModelCall(): Proceed | Guide | Transform { + return { type: 'guide', feedback: 'be more specific' } + } +} + +describe('InterventionRegistry', () => { + let hookRegistry: HookRegistryImplementation + let agent: Agent + const toolUse = { name: 'testTool', toolUseId: 'id-1', input: {} } + + beforeEach(() => { + hookRegistry = new HookRegistryImplementation() + agent = new Agent() + }) + + function makeBeforeInvocationEvent() { + return new BeforeInvocationEvent({ agent, invocationState: {} }) + } + + function makeBeforeToolCallEvent() { + return new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + } + + function makeBeforeModelCallEvent() { + return new BeforeModelCallEvent({ agent, model: {} as never, invocationState: {} }) + } + + function makeAfterModelCallEvent() { + return new AfterModelCallEvent({ + agent, + model: {} as never, + invocationState: {}, + attemptCount: 0, + stopData: { + message: new Message({ role: 'assistant', content: [new TextBlock('response')] }), + stopReason: 'endTurn', + }, + }) + } + + describe('constructor', () => { + it('rejects duplicate handler names', () => { + expect(() => new InterventionRegistry([new DenyHandler(), new DenyHandler()], hookRegistry)).toThrow( + "Duplicate intervention handler name: 'deny-handler'" + ) + }) + + it('accepts handlers with unique names', () => { + // No throw means success + new InterventionRegistry([new DenyHandler(), new GuideHandler()], hookRegistry) + }) + }) + + describe('hook registration', () => { + it('only registers hooks for overridden methods', async () => { + new InterventionRegistry([new DenyHandler()], hookRegistry) + + const beforeToolEvent = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(beforeToolEvent) + expect(beforeToolEvent.cancel).toBe('DENIED: not authorized') + + // afterModelCall should not be registered — no handler overrides it + const afterModelEvent = makeAfterModelCallEvent() + await hookRegistry.invokeCallbacks(afterModelEvent) + expect(afterModelEvent.retry).toBeUndefined() + }) + }) + + describe('dispatch ordering', () => { + it('calls handlers in registration order', async () => { + const callOrder: string[] = [] + + class First extends InterventionHandler { + readonly name = 'first' + override beforeToolCall(): InterventionAction { + callOrder.push('first') + return { type: 'proceed' } + } + } + class Second extends InterventionHandler { + readonly name = 'second' + override beforeToolCall(): InterventionAction { + callOrder.push('second') + return { type: 'proceed' } + } + } + + new InterventionRegistry([new First(), new Second()], hookRegistry) + + await hookRegistry.invokeCallbacks(makeBeforeToolCallEvent()) + expect(callOrder).toEqual(['first', 'second']) + }) + + it('skips handlers that do not override the method', async () => { + const callOrder: string[] = [] + + class ToolHandler extends InterventionHandler { + readonly name = 'tool' + override beforeToolCall(): InterventionAction { + callOrder.push('tool') + return { type: 'proceed' } + } + } + class ModelHandler extends InterventionHandler { + readonly name = 'model' + override afterModelCall(): Proceed | Guide | Transform { + callOrder.push('model') + return { type: 'proceed' } + } + } + + new InterventionRegistry([new ToolHandler(), new ModelHandler()], hookRegistry) + + await hookRegistry.invokeCallbacks(makeBeforeToolCallEvent()) + expect(callOrder).toEqual(['tool']) + }) + }) + + describe('deny', () => { + it('sets cancel on BeforeToolCallEvent', async () => { + new InterventionRegistry([new DenyHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + + expect(event.cancel).toBe('DENIED: not authorized') + }) + + it('short-circuits — later handlers do not run', async () => { + const laterCalled = vi.fn() + + class LaterHandler extends InterventionHandler { + readonly name = 'later' + override beforeToolCall(): InterventionAction { + laterCalled() + return { type: 'proceed' } + } + } + + new InterventionRegistry([new DenyHandler(), new LaterHandler()], hookRegistry) + + await hookRegistry.invokeCallbacks(makeBeforeToolCallEvent()) + expect(laterCalled).not.toHaveBeenCalled() + }) + + it('sets cancel on BeforeInvocationEvent', async () => { + class InvocationDeny extends InterventionHandler { + readonly name = 'invocation-deny' + override beforeInvocation() { + return deny('unauthorized user') + } + } + + new InterventionRegistry([new InvocationDeny()], hookRegistry) + + const event = makeBeforeInvocationEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe('DENIED: unauthorized user') + }) + + it('sets cancel on BeforeModelCallEvent', async () => { + class ModelDeny extends InterventionHandler { + readonly name = 'model-deny' + override beforeModelCall() { + return deny('prompt injection detected') + } + } + + new InterventionRegistry([new ModelDeny()], hookRegistry) + + const event = makeBeforeModelCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe('DENIED: prompt injection detected') + }) + }) + + describe('guide', () => { + it('sets cancel with guidance on BeforeToolCallEvent', async () => { + new InterventionRegistry([new GuideHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe('GUIDANCE: [guide-handler] add more context') + }) + + it('accumulates feedback from multiple handlers', async () => { + class SecondGuide extends InterventionHandler { + readonly name = 'second-guide' + override beforeToolCall(): InterventionAction { + return { type: 'guide', feedback: 'also check permissions' } + } + } + + new InterventionRegistry([new GuideHandler(), new SecondGuide()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe('GUIDANCE: [guide-handler] add more context\n[second-guide] also check permissions') + }) + + it('sets retry=true and injects guidance message on AfterModelCallEvent', async () => { + new InterventionRegistry([new ModelGuideHandler()], hookRegistry) + + const event = makeAfterModelCallEvent() + const messageCountBefore = event.agent.messages.length + await hookRegistry.invokeCallbacks(event) + + expect(event.retry).toBe(true) + expect(event.agent.messages).toHaveLength(messageCountBefore + 1) + const guidanceMessage = event.agent.messages[event.agent.messages.length - 1]! + expect(guidanceMessage.role).toBe('user') + expect(guidanceMessage.content[0]).toMatchObject({ type: 'textBlock', text: '[model-guide] be more specific' }) + }) + }) + + describe('confirm', () => { + it('pauses agent when no response is provided', async () => { + new InterventionRegistry([new ConfirmHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await expect(hookRegistry.invokeCallbacks(event)).rejects.toThrow('Interrupt raised') + }) + + it('short-circuits — later handlers do not run', async () => { + const laterCalled = vi.fn() + + class LaterHandler extends InterventionHandler { + readonly name = 'later' + override beforeToolCall(): InterventionAction { + laterCalled() + return { type: 'proceed' } + } + } + + new InterventionRegistry([new ConfirmHandler(), new LaterHandler()], hookRegistry) + + await expect(hookRegistry.invokeCallbacks(makeBeforeToolCallEvent())).rejects.toThrow() + expect(laterCalled).not.toHaveBeenCalled() + }) + + function preloadInterruptResponse(handlerName: string, response: unknown) { + const interruptId = `hook:beforeToolCall:${toolUse.toolUseId}:${handlerName}` + const interruptState = (agent as unknown as { _interruptState: InterruptState })._interruptState + interruptState.interrupts[interruptId] = new Interrupt({ + id: interruptId, + name: handlerName, + response: response as never, + source: 'hook', + }) + } + + describe('approve/deny on resume', () => { + const DENIED = 'CONFIRMATION_FAILED: approve this action?' + + it.each([ + [true, false], + ['yes', false], + ['y', false], + ['Y', false], + ['YES', false], + [' yes ', false], + ['no', DENIED], + [false, DENIED], + [null, DENIED], + ['', DENIED], + ])('response %j → cancel=%j', async (response, expectedCancel) => { + preloadInterruptResponse('confirm-handler', response) + new InterventionRegistry([new ConfirmHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe(expectedCancel) + }) + }) + + it('uses custom evaluate when provided', async () => { + class CustomApprovalHandler extends InterventionHandler { + readonly name = 'custom-approval' + override beforeToolCall(): InterventionAction { + return { + type: 'confirm', + prompt: 'approve?', + evaluate: (response) => response === 'custom-yes', + } + } + } + + // 'yes' would pass default evaluate but fails custom + preloadInterruptResponse('custom-approval', 'yes') + + new InterventionRegistry([new CustomApprovalHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe('CONFIRMATION_FAILED: approve?') + }) + + it('custom evaluate approves when its condition is met', async () => { + class CustomApprovalHandler extends InterventionHandler { + readonly name = 'custom-approval' + override beforeToolCall(): InterventionAction { + return { + type: 'confirm', + prompt: 'approve?', + evaluate: (response) => response === 'custom-yes', + } + } + } + + preloadInterruptResponse('custom-approval', 'custom-yes') + + new InterventionRegistry([new CustomApprovalHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe(false) + }) + + it('approved confirm does not short-circuit later handlers', async () => { + preloadInterruptResponse('confirm-handler', 'yes') + + const laterCalled = vi.fn() + + class LaterHandler extends InterventionHandler { + readonly name = 'later' + override beforeToolCall(): InterventionAction { + laterCalled() + return { type: 'proceed' } + } + } + + new InterventionRegistry([new ConfirmHandler(), new LaterHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe(false) + expect(laterCalled).toHaveBeenCalled() + }) + + it('denied confirm short-circuits later handlers', async () => { + preloadInterruptResponse('confirm-handler', 'no') + const laterCalled = vi.fn() + + class LaterHandler extends InterventionHandler { + readonly name = 'later' + override beforeToolCall(): InterventionAction { + laterCalled() + return { type: 'proceed' } + } + } + + new InterventionRegistry([new ConfirmHandler(), new LaterHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe('CONFIRMATION_FAILED: approve this action?') + expect(laterCalled).not.toHaveBeenCalled() + }) + + describe('preemptive response (inline mode)', () => { + it('approves when response is an approved value', async () => { + class InlineConfirmHandler extends InterventionHandler { + readonly name = 'inline-confirm' + override beforeToolCall(): InterventionAction { + return { type: 'confirm', prompt: 'approve?', response: 'yes' } + } + } + + new InterventionRegistry([new InlineConfirmHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe(false) + }) + + it('denies when response is a non-approved value', async () => { + class InlineConfirmHandler extends InterventionHandler { + readonly name = 'inline-confirm' + override beforeToolCall(): InterventionAction { + return { type: 'confirm', prompt: 'approve?', response: 'no' } + } + } + + new InterventionRegistry([new InlineConfirmHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe('CONFIRMATION_FAILED: approve?') + }) + + it('uses custom evaluate with preemptive response', async () => { + class OtpHandler extends InterventionHandler { + readonly name = 'otp-handler' + override beforeToolCall(): InterventionAction { + return { + type: 'confirm', + prompt: 'Enter OTP:', + response: '123456', + evaluate: (r) => r === '123456', + } + } + } + + new InterventionRegistry([new OtpHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe(false) + }) + + it('passes response as preemptive value so agent never pauses', async () => { + class InlineConfirmHandler extends InterventionHandler { + readonly name = 'inline-confirm' + override beforeToolCall(): InterventionAction { + return { type: 'confirm', prompt: 'approve?', response: 'yes' } + } + } + + new InterventionRegistry([new InlineConfirmHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + const interruptSpy = vi.spyOn(event, 'interrupt') + await hookRegistry.invokeCallbacks(event) + expect(interruptSpy).toHaveBeenCalledWith({ name: 'inline-confirm', reason: 'approve?', response: 'yes' }) + }) + + it('denies when response is falsy but defined (false)', async () => { + class InlineConfirmHandler extends InterventionHandler { + readonly name = 'inline-confirm' + override beforeToolCall(): InterventionAction { + return { type: 'confirm', prompt: 'approve?', response: false } + } + } + + new InterventionRegistry([new InlineConfirmHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe('CONFIRMATION_FAILED: approve?') + }) + }) + + it.each(['proceed', 'deny'] as const)( + 'InterruptError always propagates regardless of onError=%s', + async (onError) => { + class ConfirmWithOnError extends InterventionHandler { + readonly name = 'confirm-onerror' + override readonly onError = onError + override beforeToolCall(): InterventionAction { + return { type: 'confirm', prompt: 'approve?' } + } + } + + new InterventionRegistry([new ConfirmWithOnError()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await expect(hookRegistry.invokeCallbacks(event)).rejects.toThrow('Interrupt raised') + } + ) + }) + + describe('transform', () => { + it('calls the apply function with the event', async () => { + const applyFn = vi.fn() + + class TransformHandler extends InterventionHandler { + readonly name = 'transform-handler' + override beforeToolCall(): InterventionAction { + return { type: 'transform', apply: applyFn, reason: 'sanitized input' } + } + } + + new InterventionRegistry([new TransformHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(applyFn).toHaveBeenCalledWith(event) + }) + + it('later handlers see the transformed state', async () => { + const observed: string[] = [] + + class Transformer extends InterventionHandler { + readonly name = 'transformer' + override beforeToolCall(): InterventionAction { + return { + type: 'transform', + apply: (e) => { + ;(e as BeforeToolCallEvent).cancel = 'transformed' + }, + } + } + } + + class Observer extends InterventionHandler { + readonly name = 'observer' + override beforeToolCall(event: BeforeToolCallEvent): InterventionAction { + observed.push(String(event.cancel)) + return { type: 'proceed' } + } + } + + new InterventionRegistry([new Transformer(), new Observer()], hookRegistry) + + await hookRegistry.invokeCallbacks(makeBeforeToolCallEvent()) + expect(observed).toEqual(['transformed']) + }) + + it('works on AfterModelCallEvent', async () => { + const applyFn = vi.fn() + + class ModelTransform extends InterventionHandler { + readonly name = 'model-transform' + override afterModelCall(): Proceed | Guide | Transform { + return { type: 'transform', apply: applyFn, reason: 'redacted output' } + } + } + + new InterventionRegistry([new ModelTransform()], hookRegistry) + + const event = makeAfterModelCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(applyFn).toHaveBeenCalledWith(event) + }) + + it('is logged in the audit trail', async () => { + class TransformHandler extends InterventionHandler { + readonly name = 'transform-handler' + override beforeToolCall(): InterventionAction { + return { type: 'transform', apply: () => {}, reason: 'sanitized' } + } + } + + new InterventionRegistry([new TransformHandler()], hookRegistry) + + await hookRegistry.invokeCallbacks(makeBeforeToolCallEvent()) + // Transform was applied (verified by the apply fn mock tests above) + }) + }) + + describe('proceed', () => { + it('does not mutate the event', async () => { + new InterventionRegistry([new ProceedHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe(false) + }) + }) + + describe('error handling', () => { + it('onError=throw (default) rethrows the error', async () => { + new InterventionRegistry([new ThrowingHandler(), new ProceedHandler()], hookRegistry) + + await expect(hookRegistry.invokeCallbacks(makeBeforeToolCallEvent())).rejects.toThrow('handler crashed') + }) + + it('onError=proceed skips the handler and continues to next', async () => { + new InterventionRegistry([new ThrowingProceedHandler(), new ProceedHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe(false) + }) + + it('onError=deny logs the error and applies deny', async () => { + const laterCalled = vi.fn() + + class LaterHandler extends InterventionHandler { + readonly name = 'later' + override beforeToolCall(): InterventionAction { + laterCalled() + return { type: 'proceed' } + } + } + + new InterventionRegistry([new ThrowingDenyHandler(), new LaterHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + + expect(event.cancel).toBe('DENIED: Handler threw: handler crashed') + expect(laterCalled).not.toHaveBeenCalled() + }) + }) + + describe('async handlers', () => { + it('awaits async handler results', async () => { + new InterventionRegistry([new AsyncDenyHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + expect(event.cancel).toBe('DENIED: async denial') + }) + }) + + describe('conflict resolution', () => { + it('deny wins over guide', async () => { + new InterventionRegistry([new GuideHandler(), new DenyHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + + expect(event.cancel).toBe('DENIED: not authorized') + }) + + it('deny short-circuits before guide can accumulate', async () => { + new InterventionRegistry([new DenyHandler(), new GuideHandler()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + + expect(event.cancel).toBe('DENIED: not authorized') + }) + + it('confirm short-circuits before guide can accumulate', async () => { + new InterventionRegistry([new ConfirmHandler(), new GuideHandler()], hookRegistry) + + await expect(hookRegistry.invokeCallbacks(makeBeforeToolCallEvent())).rejects.toThrow('Interrupt raised') + }) + }) + + describe('agent integration', () => { + it('deny on beforeToolCall prevents tool execution', async () => { + const { MockMessageModel } = await import('../../__fixtures__/mock-message-model.js') + const { createMockTool } = await import('../../__fixtures__/tool-helpers.js') + + let toolExecuted = false + const tool = createMockTool('blockedTool', () => { + toolExecuted = true + return 'should not reach here' + }) + + class BlockAllTools extends InterventionHandler { + readonly name = 'block-all' + override beforeToolCall(): InterventionAction { + return { type: 'deny', reason: 'blocked by intervention' } + } + } + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'blockedTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [new BlockAllTools()], + }) + + const result = await agent.invoke('Test') + + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(false) + }) + + it('interventions run before plugins (HookOrder.INTERVENTIONS < DEFAULT)', async () => { + const { MockMessageModel } = await import('../../__fixtures__/mock-message-model.js') + const { createMockTool } = await import('../../__fixtures__/tool-helpers.js') + + const callOrder: string[] = [] + + const tool = createMockTool('testTool', () => 'result') + + class OrderTracker extends InterventionHandler { + readonly name = 'order-tracker' + override beforeToolCall(): InterventionAction { + callOrder.push('intervention') + return { type: 'proceed' } + } + } + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'testTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [new OrderTracker()], + }) + + agent.addHook(BeforeToolCallEvent, () => { + callOrder.push('plugin') + }) + + await agent.invoke('Test') + + // On Before*: plugins run first (DEFAULT:0), interventions last (INTERVENTIONS:90) + expect(callOrder[0]).toBe('plugin') + expect(callOrder[1]).toBe('intervention') + }) + }) + + describe('edge cases', () => { + it('guide on beforeModelCall injects a user message', async () => { + class ModelGuide extends InterventionHandler { + readonly name = 'model-guide' + override beforeModelCall(): Proceed | Guide | Transform { + return { type: 'guide', feedback: 'check your sources' } + } + } + + new InterventionRegistry([new ModelGuide()], hookRegistry) + + const event = makeBeforeModelCallEvent() + const messageCountBefore = event.agent.messages.length + await hookRegistry.invokeCallbacks(event) + + expect(event.cancel).toBe(false) + expect(event.agent.messages).toHaveLength(messageCountBefore + 1) + const injected = event.agent.messages[event.agent.messages.length - 1]! + expect(injected.role).toBe('user') + expect(injected.content[0]).toMatchObject({ type: 'textBlock', text: '[model-guide] check your sources' }) + }) + + it('transform apply() error is handled via onError policy', async () => { + class BadTransform extends InterventionHandler { + readonly name = 'bad-transform' + override readonly onError = 'proceed' as const + override beforeToolCall(): InterventionAction { + return { + type: 'transform', + apply: () => { + throw new Error('apply boom') + }, + } + } + } + + class AfterTransform extends InterventionHandler { + readonly name = 'after-transform' + override beforeToolCall(): InterventionAction { + return { type: 'proceed', reason: 'still running' } + } + } + + new InterventionRegistry([new BadTransform(), new AfterTransform()], hookRegistry) + + const event = makeBeforeToolCallEvent() + await hookRegistry.invokeCallbacks(event) + // onError=proceed means the error is swallowed and next handler runs + expect(event.cancel).toBe(false) + }) + + it('transform apply() error with onError=throw propagates', async () => { + class BadTransform extends InterventionHandler { + readonly name = 'bad-transform' + override readonly onError = 'throw' as const + override beforeToolCall(): InterventionAction { + return { + type: 'transform', + apply: () => { + throw new Error('apply boom') + }, + } + } + } + + new InterventionRegistry([new BadTransform()], hookRegistry) + + await expect(hookRegistry.invokeCallbacks(makeBeforeToolCallEvent())).rejects.toThrow('apply boom') + }) + + it('warns when action has no effect on event type', async () => { + const { logger } = await import('../../logging/logger.js') + const warnSpy = vi.spyOn(logger, 'warn') + + // Force a confirm return on beforeInvocation (which doesn't support it) + // via cast to test the runtime warning path + class InterruptOnInvocation extends InterventionHandler { + readonly name = 'confirm-invocation' + override beforeInvocation() { + // Force a confirm return via any cast to test the runtime warning + return { type: 'confirm', prompt: 'test' } as never + } + } + + new InterventionRegistry([new InterruptOnInvocation()], hookRegistry) + + await hookRegistry.invokeCallbacks(makeBeforeInvocationEvent()) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('has no effect')) + + warnSpy.mockRestore() + }) + }) +}) diff --git a/strands-ts/src/interventions/actions.ts b/strands-ts/src/interventions/actions.ts new file mode 100644 index 0000000000..1031e903d3 --- /dev/null +++ b/strands-ts/src/interventions/actions.ts @@ -0,0 +1,208 @@ +import type { + BeforeInvocationEvent, + BeforeToolCallEvent, + AfterToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, +} from '../hooks/events.js' +import type { JSONValue } from '../types/json.js' + +const APPROVED_RESPONSES = new Set(['y', 'yes']) + +/** + * Default evaluate function for the confirm action. + * Accepts: true, 'y'/'yes' (case-insensitive, whitespace-trimmed). + * + * @param response - The human's response value to evaluate. + * @returns true if the response is considered an approval, false otherwise. + */ +export function defaultEvaluate(response: JSONValue): boolean { + if (response === true) return true + if (typeof response === 'string') return APPROVED_RESPONSES.has(response.toLowerCase().trim()) + return false +} + +export type LifecycleEvent = + | BeforeInvocationEvent + | BeforeToolCallEvent + | AfterToolCallEvent + | BeforeModelCallEvent + | AfterModelCallEvent + +/** + * Allow the operation to continue unchanged. + * + * @param reason - Optional metadata for debugging/logging. Not shown to the model. + * + * @example + * ```typescript + * return { type: 'proceed' } + * ``` + */ +export type Proceed = { type: 'proceed'; reason?: string } + +/** + * Block the operation. On Before* events, sets event.cancel with the reason text. + * The reason is shown to the model as the cancellation message. + * + * @param reason - Why the operation was blocked. Shown to the model. + * + * @example + * ```typescript + * override beforeToolCall(event: BeforeToolCallEvent): InterventionAction { + * if (!this.isAuthorized(event.agent.appState.get('user_id'), event.toolUse.name)) { + * return { type: 'deny', reason: 'User not authorized for this tool' } + * } + * return { type: 'proceed' } + * } + * ``` + */ +export type Deny = { type: 'deny'; reason: string } + +/** + * Provide feedback to steer behavior. On beforeToolCall/beforeInvocation, sets + * event.cancel so the model sees the feedback and adjusts. On beforeModelCall, + * injects feedback as a user message so the model sees it on this call. + * On afterModelCall, the response is discarded and the model retries with the + * feedback injected as a user message. + * + * @param feedback - The guidance text shown to the model. + * @param reason - Optional metadata for debugging/logging. Not shown to the model. + * + * @example + * ```typescript + * override afterModelCall(event: AfterModelCallEvent): InterventionAction { + * if (this.isTooVague(event.stopData?.message)) { + * return { type: 'guide', feedback: 'Be more specific in your response.' } + * } + * return { type: 'proceed' } + * } + * ``` + */ +export type Guide = { type: 'guide'; feedback: string; reason?: string } + +/** + * Request human approval before proceeding. Only supported on beforeToolCall. + * + * Two modes depending on whether `response` is provided: + * - With `response`: passed as a preemptive value to the interrupt system, agent + * never pauses. Handlers collect the response themselves (e.g. via readline). + * - Without `response`: breaks out of the agent loop to pause for external resume. + * + * The response is checked against `evaluate` (defaults to accepting `true` or + * `'y'`/`'yes'` case-insensitive). If denied, sets event.cancel. + * + * @example + * ```typescript + * // Inline mode (handler collected the response already) + * const answer = await rl.question(`${prompt} (y/n): `) + * return confirm(prompt, { response: answer }) + * + * // Stateless mode (interrupt/resume) + * return confirm(`Approve ${event.toolUse.name}?`) + * ``` + */ +export type Confirm = { + type: 'confirm' + prompt: string + reason?: string + response?: JSONValue + evaluate?: (response: JSONValue) => boolean +} + +/** + * Modify event content in-place. The `apply` function mutates the event before + * execution proceeds. Later handlers in the pipeline see the transformed content. + * + * The handler already has the typed event from its lifecycle method, so `apply` + * can close over it directly — no cast needed: + * + * @param apply - Function that mutates the event. Not shown to the model. + * @param reason - Optional metadata for debugging/logging. Not shown to the model. + * + * @example + * ```typescript + * override beforeToolCall(event: BeforeToolCallEvent): InterventionAction { + * const redacted = redactPII(event.toolUse.input) + * return { + * type: 'transform', + * apply: () => { event.toolUse.input = redacted }, + * reason: 'PII redacted from tool input', + * } + * } + * ``` + */ +export type Transform = { type: 'transform'; apply: (event: LifecycleEvent) => void; reason?: string } + +/** + * Union of all intervention actions a handler can return. + * + * | Action | beforeInvocation | beforeToolCall | beforeModelCall | afterToolCall | afterModelCall | + * |-----------|------------------|----------------|-----------------|---------------|----------------| + * | Proceed | — | — | — | — | — | + * | Deny | cancel | cancel | cancel | — | — | + * | Guide | cancel+ | cancel+ | inject | — | inject + retry | + * | Confirm | — | confirm | — | — | — | + * | Transform | apply | apply | apply | apply | apply | + * + * — = no-op (logged in audit trail, warns at runtime) + * cancel = sets event.cancel, short-circuits (remaining handlers skipped) + * cancel+ = sets event.cancel with accumulated feedback from all guiding handlers + * confirm = uses preemptive response or interrupt, checks with evaluate, sets cancel if denied + * inject = appends accumulated feedback as a user message so the model sees it on this call + * inject + retry = appends accumulated feedback and retries so the model sees guidance + * apply = calls action.apply(event) for in-place mutation, later handlers see the change + */ +export type InterventionAction = Proceed | Deny | Guide | Confirm | Transform + +/** + * Allow the operation to continue. + * @param options - Options: reason (debug metadata). + */ +export function proceed(options?: { reason?: string }): Proceed { + return { type: 'proceed', ...options } +} + +/** + * Block the operation. + * @param reason - Why the operation was blocked. Shown to the model. + */ +export function deny(reason: string): Deny { + return { type: 'deny', reason } +} + +/** + * Provide feedback to steer behavior. + * @param feedback - The guidance text shown to the model. + * @param options - Options: reason (debug metadata). + */ +export function guide(feedback: string, options?: { reason?: string }): Guide { + return { type: 'guide', feedback, ...options } +} + +/** + * Request human approval. + * @param prompt - Message shown to the human. Not shown to the model. + * @param options - Options: reason (debug metadata), evaluate (custom response + * validator, defaults to accepting true or y/yes case-insensitive), response + * (pre-collected value to skip pausing the agent). + */ +export function confirm( + prompt: string, + options?: { + reason?: string + response?: JSONValue + evaluate?: (response: JSONValue) => boolean + } +): Confirm { + return { type: 'confirm', prompt, evaluate: defaultEvaluate, ...options } +} + +/** + * Modify event content in-place. + * @param apply - Function that mutates the event. + * @param options - Options: reason (debug metadata). + */ +export function transform(apply: (event: LifecycleEvent) => void, options?: { reason?: string }): Transform { + return { type: 'transform', apply, ...options } +} diff --git a/strands-ts/src/interventions/handler.ts b/strands-ts/src/interventions/handler.ts new file mode 100644 index 0000000000..c5a4c88913 --- /dev/null +++ b/strands-ts/src/interventions/handler.ts @@ -0,0 +1,67 @@ +import type { + BeforeInvocationEvent, + BeforeToolCallEvent, + AfterToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, +} from '../hooks/events.js' +import type { Proceed, Deny, Guide, Confirm, Transform } from './actions.js' + +export type Awaitable = T | Promise + +/** + * What to do when a handler throws during evaluation. + * + * - `'throw'` — rethrow the error (default, safest: a broken policy check blocks execution) + * - `'proceed'` — log the error and continue as if the handler returned Proceed + * - `'deny'` — log the error and treat it as a Deny (fail-closed) + */ +export type OnError = 'throw' | 'proceed' | 'deny' + +/** + * Base class for intervention handlers. + * + * Handlers override the lifecycle methods they care about. Default implementations + * return Proceed. The framework detects which methods are overridden and only + * registers hook callbacks for those. + * + * @example + * ```typescript + * class CedarAuth extends InterventionHandler { + * readonly name = 'cedar-auth' + * + * override beforeToolCall(event: BeforeToolCallEvent): InterventionAction { + * if (!this.isAuthorized(event)) { + * return deny('User not authorized for this tool') + * } + * return proceed() + * } + * } + * ``` + */ +export abstract class InterventionHandler { + abstract readonly name: string + + /** What to do when this handler throws. Defaults to 'throw'. */ + readonly onError: OnError = 'throw' + + beforeInvocation(_event: BeforeInvocationEvent): Awaitable { + return { type: 'proceed' } + } + + beforeToolCall(_event: BeforeToolCallEvent): Awaitable { + return { type: 'proceed' } + } + + afterToolCall(_event: AfterToolCallEvent): Awaitable { + return { type: 'proceed' } + } + + beforeModelCall(_event: BeforeModelCallEvent): Awaitable { + return { type: 'proceed' } + } + + afterModelCall(_event: AfterModelCallEvent): Awaitable { + return { type: 'proceed' } + } +} diff --git a/strands-ts/src/interventions/index.ts b/strands-ts/src/interventions/index.ts new file mode 100644 index 0000000000..05463eb98d --- /dev/null +++ b/strands-ts/src/interventions/index.ts @@ -0,0 +1,5 @@ +export type { InterventionAction, LifecycleEvent, Proceed, Deny, Guide, Confirm, Transform } from './actions.js' +import { proceed, deny, guide, confirm, transform } from './actions.js' +export const InterventionActions = { proceed, deny, guide, confirm, transform } +export { InterventionHandler } from './handler.js' +export type { OnError } from './handler.js' diff --git a/strands-ts/src/interventions/registry.ts b/strands-ts/src/interventions/registry.ts new file mode 100644 index 0000000000..16e61520ea --- /dev/null +++ b/strands-ts/src/interventions/registry.ts @@ -0,0 +1,278 @@ +import { + BeforeInvocationEvent, + BeforeToolCallEvent, + AfterToolCallEvent, + BeforeModelCallEvent, + AfterModelCallEvent, + type HookableEvent, +} from '../hooks/events.js' +import type { HookRegistry } from '../hooks/registry.js' +import { HookOrder } from '../hooks/types.js' +import { Message, TextBlock } from '../types/messages.js' +import type { Guide, InterventionAction } from './actions.js' +import { defaultEvaluate } from './actions.js' +import { InterventionHandler } from './handler.js' +import { InterruptError } from '../interrupt.js' +import { logger } from '../logging/logger.js' +import type { JSONValue } from '../types/json.js' + +type LifecycleMethod = 'beforeInvocation' | 'beforeToolCall' | 'afterToolCall' | 'beforeModelCall' | 'afterModelCall' + +/** + * Bridges {@link InterventionHandler} instances and the Strands hook system. + * + * Registers one hook callback per lifecycle event type, then dispatches to + * all handlers that override that method — in registration order, with + * short-circuiting on Deny (and denied Confirms) and accumulation for Guide. + * + * See {@link InterventionAction} for the action-to-event compatibility matrix. + */ +export class InterventionRegistry { + private readonly _handlers: InterventionHandler[] + + constructor(handlers: InterventionHandler[], hookRegistry: HookRegistry) { + const seen = new Set() + for (const h of handlers) { + if (seen.has(h.name)) { + throw new Error(`Duplicate intervention handler name: '${h.name}'`) + } + seen.add(h.name) + } + this._handlers = handlers + this._registerHooks(hookRegistry) + } + + /** Registered handlers in registration order. */ + get handlers(): readonly InterventionHandler[] { + return this._handlers + } + + private _registerHooks(hookRegistry: HookRegistry): void { + if (this._handlers.some((h) => h.beforeInvocation !== InterventionHandler.prototype.beforeInvocation)) { + hookRegistry.addCallback(BeforeInvocationEvent, (e) => this._onBeforeInvocation(e), { + order: HookOrder.INTERVENTION_INPUT, + }) + } + if (this._handlers.some((h) => h.beforeToolCall !== InterventionHandler.prototype.beforeToolCall)) { + hookRegistry.addCallback(BeforeToolCallEvent, (e) => this._onBeforeToolCall(e), { + order: HookOrder.INTERVENTION_INPUT, + }) + } + if (this._handlers.some((h) => h.afterToolCall !== InterventionHandler.prototype.afterToolCall)) { + hookRegistry.addCallback(AfterToolCallEvent, (e) => this._onAfterToolCall(e), { + order: HookOrder.INTERVENTION_OUTPUT, + }) + } + if (this._handlers.some((h) => h.beforeModelCall !== InterventionHandler.prototype.beforeModelCall)) { + hookRegistry.addCallback(BeforeModelCallEvent, (e) => this._onBeforeModelCall(e), { + order: HookOrder.INTERVENTION_INPUT, + }) + } + if (this._handlers.some((h) => h.afterModelCall !== InterventionHandler.prototype.afterModelCall)) { + hookRegistry.addCallback(AfterModelCallEvent, (e) => this._onAfterModelCall(e), { + order: HookOrder.INTERVENTION_OUTPUT, + }) + } + } + + private async _onBeforeInvocation(event: BeforeInvocationEvent): Promise { + return this._dispatch(event, 'beforeInvocation', (action, handlerName) => { + switch (action.type) { + case 'deny': + event.cancel = `DENIED: ${action.reason}` + return true + case 'guide': + event.cancel = `GUIDANCE: ${action.feedback}` + return false + case 'transform': + action.apply(event) + return false + case 'proceed': + return false + default: + logger.warn(`handler=<${handlerName}>, event= | ${action.type} has no effect`) + return false + } + }) + } + + private async _onBeforeToolCall(event: BeforeToolCallEvent): Promise { + return this._dispatch(event, 'beforeToolCall', (action, handlerName) => { + const actionType = action.type + switch (actionType) { + case 'deny': + event.cancel = `DENIED: ${action.reason}` + return true + case 'confirm': { + // If response is provided, it's passed as a preemptive value to + // event.interrupt() — the interrupt is registered but never pauses. + // If no response, event.interrupt() throws InterruptError on first + // call (pausing the agent for external resume). + const result = event.interrupt({ + name: handlerName, + reason: action.prompt, + ...(action.response !== undefined && { response: action.response }), + }) + const check = action.evaluate ?? defaultEvaluate + if (!check(result)) { + event.cancel = `CONFIRMATION_FAILED: ${action.prompt}` + return true + } + return false + } + case 'guide': + event.cancel = `GUIDANCE: ${action.feedback}` + return false + case 'transform': + action.apply(event) + return false + case 'proceed': + return false + default: + logger.warn(`handler=<${handlerName}>, event= | ${actionType} has no effect`) + return false + } + }) + } + + private async _onAfterToolCall(event: AfterToolCallEvent): Promise { + return this._dispatch(event, 'afterToolCall', (action, handlerName) => { + switch (action.type) { + case 'transform': + action.apply(event) + return false + case 'proceed': + return false + default: + logger.warn(`handler=<${handlerName}>, event= | ${action.type} has no effect`) + return false + } + }) + } + + // Guide on beforeModelCall injects feedback as a user message so the model sees + // it on this call, rather than cancelling (which would end the invocation). + private async _onBeforeModelCall(event: BeforeModelCallEvent): Promise { + return this._dispatch(event, 'beforeModelCall', (action, handlerName) => { + switch (action.type) { + case 'deny': + event.cancel = `DENIED: ${action.reason}` + return true + case 'guide': + // Direct push bypasses MessageAddedEvent and conversation manager. + // This matches what plugins can do today via event.agent.messages. + event.agent.messages.push(new Message({ role: 'user', content: [new TextBlock(action.feedback)] })) + return false + case 'transform': + action.apply(event) + return false + case 'proceed': + return false + default: + logger.warn(`handler=<${handlerName}>, event= | ${action.type} has no effect`) + return false + } + }) + } + + private async _onAfterModelCall(event: AfterModelCallEvent): Promise { + return this._dispatch(event, 'afterModelCall', (action, handlerName) => { + switch (action.type) { + case 'guide': + event.retry = true + // Direct push bypasses MessageAddedEvent and conversation manager, so this + // message won't trigger context management and could push the context over + // the limit. LocalAgent doesn't expose a message-append method that goes + // through the hook pipeline. This matches what plugins can do today. + event.agent.messages.push(new Message({ role: 'user', content: [new TextBlock(action.feedback)] })) + return false + case 'transform': + action.apply(event) + return false + case 'proceed': + return false + default: + logger.warn(`handler=<${handlerName}>, event= | ${action.type} has no effect`) + return false + } + }) + } + + /** + * Iterate handlers in registration order and resolve the winning action. + * + * - Deny short-circuits immediately (remaining handlers are skipped). + * - Confirm pauses for human input; if denied short-circuits, if approved continues. + * - Guide feedback strings accumulate across handlers and are applied at the end. + * - Transform is applied in-place so later handlers see the mutation. + * - If a handler throws, behavior depends on {@link InterventionHandler.onError}: + * `'throw'` (default) rethrows, `'deny'` fails closed, `'proceed'` skips. + */ + private async _dispatch( + event: HookableEvent, + method: LifecycleMethod, + apply: (action: InterventionAction, handlerName: string) => boolean + ): Promise { + logger.debug(`event=<${method}> | dispatching to ${this._handlers.length} handler(s)`) + const guides: Array<{ handlerName: string; action: Guide }> = [] + + for (const handler of this._handlers) { + if (handler[method] === InterventionHandler.prototype[method]) continue + + logger.debug(`handler=<${handler.name}>, event=<${method}> | evaluating`) + + let action: InterventionAction | undefined + try { + action = await handler[method](event as never) + } catch (error) { + action = this._handleError(handler, method, error) + if (!action) continue + } + + logger.debug(`handler=<${handler.name}>, event=<${method}> | returned ${action.type}`) + + if (action.type === 'guide') { + guides.push({ handlerName: handler.name, action }) + } else { + try { + if (apply(action, handler.name)) { + logger.debug(`handler=<${handler.name}>, event=<${method}> | short-circuited`) + return + } + } catch (error) { + // InterruptError is intentional control flow (pauses the agent), + // not a handler failure. Always propagate regardless of onError. + if (error instanceof InterruptError) { + throw error + } + const errorAction = this._handleError(handler, method, error) + if (errorAction) { + if (apply(errorAction, handler.name)) { + return + } + } + } + } + } + + // Guide feedback accumulates across handlers. Only applied if + // no earlier handler short-circuited (deny/confirm). + if (guides.length > 0) { + logger.debug(`event=<${method}> | applying accumulated guide from ${guides.length} handler(s)`) + const feedback = guides.map((g) => `[${g.handlerName}] ${g.action.feedback}`).join('\n') + apply({ type: 'guide', feedback }, '') + } + } + + private _handleError(handler: InterventionHandler, method: string, error: unknown): InterventionAction | undefined { + const errorMsg = error instanceof Error ? error.message : String(error) + + if (handler.onError === 'throw') { + throw error + } else if (handler.onError === 'deny') { + return { type: 'deny', reason: `Handler threw: ${errorMsg}` } + } else { + return undefined + } + } +} diff --git a/strands-ts/src/logging/__tests__/logger.test.ts b/strands-ts/src/logging/__tests__/logger.test.ts new file mode 100644 index 0000000000..3b9c41067e --- /dev/null +++ b/strands-ts/src/logging/__tests__/logger.test.ts @@ -0,0 +1,96 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { configureLogging, logger } from '../logger.js' + +describe('configureLogging', () => { + let originalLogger: typeof logger + + beforeEach(() => { + // Store original logger + originalLogger = logger + }) + + afterEach(() => { + // Restore original logger + configureLogging(originalLogger) + }) + + it('allows custom logger injection', () => { + const customLogger = { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + } + + configureLogging(customLogger) + + logger.debug('Debug message') + logger.info('Info message') + logger.warn('Warn message') + logger.error('Error message') + + expect(customLogger.debug).toHaveBeenCalledWith('Debug message') + expect(customLogger.info).toHaveBeenCalledWith('Info message') + expect(customLogger.warn).toHaveBeenCalledWith('Warn message') + expect(customLogger.error).toHaveBeenCalledWith('Error message') + }) + + it('passes multiple arguments to logger', () => { + const customLogger = { + debug: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + } + + configureLogging(customLogger) + + const obj = { key: 'value' } + const arr = [1, 2, 3] + logger.error('Error message', obj, arr, 123, true) + + expect(customLogger.error).toHaveBeenCalledWith('Error message', obj, arr, 123, true) + }) +}) + +describe('default logger', () => { + it('logs warnings to console.warn', () => { + const warnSpy = vi.spyOn(console, 'warn') + + logger.warn('Warning message', 'arg1', 'arg2') + + expect(warnSpy).toHaveBeenCalledWith('Warning message', 'arg1', 'arg2') + + warnSpy.mockRestore() + }) + + it('logs errors to console.error', () => { + const errorSpy = vi.spyOn(console, 'error') + + logger.error('Error message', 'arg1', 'arg2') + + expect(errorSpy).toHaveBeenCalledWith('Error message', 'arg1', 'arg2') + + errorSpy.mockRestore() + }) + + it('does not log debug messages', () => { + const debugSpy = vi.spyOn(console, 'debug') + + logger.debug('Debug message') + + expect(debugSpy).not.toHaveBeenCalled() + + debugSpy.mockRestore() + }) + + it('does not log info messages', () => { + const infoSpy = vi.spyOn(console, 'info') + + logger.info('Info message') + + expect(infoSpy).not.toHaveBeenCalled() + + infoSpy.mockRestore() + }) +}) diff --git a/strands-ts/src/logging/__tests__/warn-once.test.ts b/strands-ts/src/logging/__tests__/warn-once.test.ts new file mode 100644 index 0000000000..db1a2a7924 --- /dev/null +++ b/strands-ts/src/logging/__tests__/warn-once.test.ts @@ -0,0 +1,34 @@ +import { describe, it, expect, vi } from 'vitest' +import type { Logger } from '../types.js' +import { warnOnce } from '../warn-once.js' + +function createLogger(): Logger { + return { debug: vi.fn(), info: vi.fn(), warn: vi.fn(), error: vi.fn() } +} + +describe('warnOnce', () => { + it('emits a warning the first time a message is seen', () => { + const logger = createLogger() + warnOnce(logger, 'first-seen-msg') + expect(logger.warn).toHaveBeenCalledTimes(1) + expect(logger.warn).toHaveBeenCalledWith('first-seen-msg') + }) + + it('does not emit repeated warnings for the same message', () => { + const logger = createLogger() + warnOnce(logger, 'repeated-msg') + warnOnce(logger, 'repeated-msg') + warnOnce(logger, 'repeated-msg') + expect(logger.warn).toHaveBeenCalledTimes(1) + }) + + it('emits distinct messages independently', () => { + const logger = createLogger() + warnOnce(logger, 'distinct-alpha-msg') + warnOnce(logger, 'distinct-beta-msg') + warnOnce(logger, 'distinct-alpha-msg') + expect(logger.warn).toHaveBeenCalledTimes(2) + expect(logger.warn).toHaveBeenNthCalledWith(1, 'distinct-alpha-msg') + expect(logger.warn).toHaveBeenNthCalledWith(2, 'distinct-beta-msg') + }) +}) diff --git a/strands-ts/src/logging/index.ts b/strands-ts/src/logging/index.ts new file mode 100644 index 0000000000..6f81d5fc5c --- /dev/null +++ b/strands-ts/src/logging/index.ts @@ -0,0 +1,6 @@ +/** + * Logging module exports. + */ + +export { configureLogging, logger } from './logger.js' +export type { Logger } from './types.js' diff --git a/strands-ts/src/logging/logger.ts b/strands-ts/src/logging/logger.ts new file mode 100644 index 0000000000..e580113b19 --- /dev/null +++ b/strands-ts/src/logging/logger.ts @@ -0,0 +1,46 @@ +/** + * Logger configuration. + * + * This module provides simple logging infrastructure for the Strands SDK. + * Users can inject their own logger implementation to control logging behavior. + */ + +import type { Logger } from './types.js' + +/** + * Default logger implementation. + * + * Only logs warnings and errors to console. Debug and info are no-ops. + */ +const defaultLogger: Logger = { + debug: () => {}, + info: () => {}, + warn: (...args: unknown[]) => console.warn(...args), + error: (...args: unknown[]) => console.error(...args), +} + +/** + * Global logger instance. + */ +export let logger: Logger = defaultLogger + +/** + * Configures the global logger. + * + * Allows users to inject their own logger implementation (e.g., Pino, Winston) + * to control logging behavior, levels, and formatting. + * + * @param customLogger - The logger implementation to use + * + * @example + * ```typescript + * import pino from 'pino' + * import { configureLogging } from '@strands-agents/sdk' + * + * const logger = pino({ level: 'debug' }) + * configureLogging(logger) + * ``` + */ +export function configureLogging(customLogger: Logger): void { + logger = customLogger +} diff --git a/strands-ts/src/logging/types.ts b/strands-ts/src/logging/types.ts new file mode 100644 index 0000000000..1855246def --- /dev/null +++ b/strands-ts/src/logging/types.ts @@ -0,0 +1,30 @@ +/** + * Logging types for the Strands SDK. + */ + +/** + * Logger interface. + * + * Compatible with standard logging libraries like Pino, Winston, and console. + */ +export interface Logger { + /** + * Log a debug message. + */ + debug(...args: unknown[]): void + + /** + * Log an info message. + */ + info(...args: unknown[]): void + + /** + * Log a warning message. + */ + warn(...args: unknown[]): void + + /** + * Log an error message. + */ + error(...args: unknown[]): void +} diff --git a/strands-ts/src/logging/warn-once.ts b/strands-ts/src/logging/warn-once.ts new file mode 100644 index 0000000000..03b505f15e --- /dev/null +++ b/strands-ts/src/logging/warn-once.ts @@ -0,0 +1,19 @@ +import type { Logger } from './types.js' + +const warned = new Set() + +/** + * Emits a warning log at most once per unique message per process. + * + * Subsequent calls with the same message are no-ops, which prevents + * repeated nudges (e.g. "using default modelId") from flooding logs + * when many instances are constructed. + * + * @param logger - Logger to emit the warning on + * @param msg - Warning message; also used as the dedupe key + */ +export function warnOnce(logger: Logger, msg: string): void { + if (warned.has(msg)) return + logger.warn(msg) + warned.add(msg) +} diff --git a/strands-ts/src/mcp-config.ts b/strands-ts/src/mcp-config.ts new file mode 100644 index 0000000000..daa42ddb44 --- /dev/null +++ b/strands-ts/src/mcp-config.ts @@ -0,0 +1,188 @@ +import type { McpClientConfig, McpClientCredentials, McpClientOptions, McpTransport, TasksConfig } from './mcp.js' +import { logger } from './logging/index.js' + +/** + * Configuration for a single MCP server entry in a config file or object. + * + * Provide either `command` (stdio transport) or `url` (streamable-http/SSE), not both. + * When `transport` is omitted, it is auto-detected from the fields present. + */ +export interface McpServerConfig { + /** Command to spawn (stdio transport, supports `${VAR}` or `${env:VAR}` interpolation). */ + command?: string + /** Arguments passed to the command (supports `${VAR}` or `${env:VAR}` interpolation). */ + args?: string[] + /** Environment variables passed to the child process (supports `${VAR}` or `${env:VAR}` interpolation). */ + env?: Record + /** Working directory for the spawned process (supports `${VAR}` or `${env:VAR}` interpolation). */ + cwd?: string + /** Server endpoint URL (streamable-http or SSE transport, supports `${VAR}` or `${env:VAR}` interpolation). */ + url?: string + /** HTTP headers sent with every request (supports `${VAR}` or `${env:VAR}` interpolation). */ + headers?: Record + /** Explicit transport type. When omitted, auto-detected: `command` → stdio, `url` → streamable-http. */ + transport?: 'stdio' | 'sse' | 'streamable-http' + /** Client credentials for OAuth machine-to-machine auth (streamable-http only). */ + auth?: McpClientCredentials + /** When true, this server is skipped during loadServers. */ + disabled?: boolean + /** When true, config or connection failures skip this server instead of throwing. */ + continueOnError?: boolean + /** Task-augmented tool execution configuration (experimental). */ + tasksConfig?: TasksConfig +} + +/** + * Resolves an MCP servers config into an array of client configurations ready for instantiation. + * + * @param config - A file path to a JSON config, or a flat server map object. + * @param defaults - Options applied to all clients unless overridden per-server. + * @returns Resolved McpClientConfig array (one per enabled, successfully-resolved server). + */ +export async function resolveServerConfigs( + config: string | Record, + defaults?: McpClientOptions +): Promise { + const servers = await loadServersObject(config) + const results: McpClientConfig[] = [] + + for (const [name, server] of Object.entries(servers)) { + if (!server || typeof server !== 'object' || Array.isArray(server)) { + throw new Error(`Server "${name}" must be an object, got ${Array.isArray(server) ? 'array' : typeof server}`) + } + + if (server.disabled) continue + + const continueOnError = server.continueOnError ?? defaults?.continueOnError ?? false + + try { + if (server.command && server.url && !server.transport) { + throw new Error('Server config has both "command" and "url" — set "transport" explicitly or remove one') + } + + const type = server.transport ?? (server.command ? 'stdio' : server.url ? 'streamable-http' : undefined) + if (!type) throw new Error('Server config must include either "command" (stdio) or "url" (http)') + + let clientConfig: McpClientConfig + switch (type) { + case 'stdio': + clientConfig = await buildStdioConfig(server) + break + case 'streamable-http': + clientConfig = buildHttpConfig(server) + break + case 'sse': + clientConfig = await buildSseConfig(server) + break + default: { + const _exhaustive: never = type + throw new Error(`Unsupported transport type: ${_exhaustive}`) + } + } + + results.push({ ...baseOptions(name, server, defaults), ...clientConfig }) + } catch (error) { + if (!continueOnError) throw error + logger.warn(`server=<${name}>, error=<${error}> | MCP server config failed, skipping (continueOnError)`) + } + } + + return results +} + +async function buildStdioConfig(server: McpServerConfig): Promise { + if (!server.command) throw new Error('Stdio transport requires "command" field') + const { StdioClientTransport, getDefaultEnvironment } = await import('@modelcontextprotocol/sdk/client/stdio.js') + + const opts: ConstructorParameters[0] = { + command: interpolateEnv(server.command), + } + if (server.args) opts.args = server.args.map(interpolateEnv) + if (server.env) opts.env = { ...getDefaultEnvironment(), ...interpolateRecord(server.env) } + if (server.cwd) opts.cwd = interpolateEnv(server.cwd) + + return { transport: new StdioClientTransport(opts) as McpTransport } +} + +function buildHttpConfig(server: McpServerConfig): McpClientConfig { + if (!server.url) throw new Error('Streamable HTTP transport requires "url" field') + + const config: McpClientConfig = { url: interpolateEnv(server.url) } + if (server.headers) config.headers = interpolateRecord(server.headers) + if (server.auth) { + config.auth = { + clientId: interpolateEnv(server.auth.clientId), + clientSecret: interpolateEnv(server.auth.clientSecret), + ...(server.auth.scopes && { scopes: server.auth.scopes.map(interpolateEnv) }), + } + } + return config +} + +async function buildSseConfig(server: McpServerConfig): Promise { + if (!server.url) throw new Error('SSE transport requires "url" field') + if (server.auth) + throw new Error('SSE transport does not support auth — use streamable-http or provide a pre-configured transport') + + const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js') + const headers = server.headers ? interpolateRecord(server.headers) : undefined + + return { + transport: new SSEClientTransport( + new URL(interpolateEnv(server.url)), + headers ? { requestInit: { headers } } : undefined + ) as McpTransport, + } +} + +function baseOptions(name: string, server: McpServerConfig, defaults?: McpClientOptions): McpClientOptions { + const opts: McpClientOptions = { ...defaults, applicationName: defaults?.applicationName ?? name } + if (server.continueOnError != null) opts.continueOnError = server.continueOnError + if (server.tasksConfig != null) opts.tasksConfig = server.tasksConfig + return opts +} + +/** + * Replaces `$\{VAR\}` and `$\{env:VAR\}` placeholders with their process.env values. + * Throws if a referenced variable is not set. + * + * @example + * ```typescript + * interpolateEnv('Bearer $\{TOKEN\}') // → 'Bearer ghp_abc123' + * interpolateEnv('$\{env:HOME\}/config') // → '/home/user/config' + * ``` + */ +function interpolateEnv(value: string): string { + return value.replace(/\$\{(?:env:)?([^}]+)\}/g, (_, key: string) => { + const resolved = process.env[key] + if (resolved === undefined) throw new Error(`Environment variable "${key}" is not set`) + return resolved + }) +} + +/** Applies {@link interpolateEnv} to every value in a string record. */ +function interpolateRecord(record: Record): Record { + return Object.fromEntries(Object.entries(record).map(([k, v]) => [k, interpolateEnv(v)])) +} + +async function loadServersObject( + config: string | Record +): Promise> { + if (typeof config !== 'string') return config + + const { readFile } = await import('node:fs/promises') + const { homedir } = await import('node:os') + const { join } = await import('node:path') + + const filePath = config.startsWith('~/') ? join(homedir(), config.slice(2)) : config + const parsed = JSON.parse(await readFile(filePath, 'utf-8')) + const servers = parsed.mcpServers ?? parsed + + if (!servers || typeof servers !== 'object' || Array.isArray(servers)) { + throw new Error( + 'MCP config must be a JSON object mapping server names to configs, e.g. { "my-server": { "command": "node" } }' + ) + } + + return servers +} diff --git a/strands-ts/src/mcp.ts b/strands-ts/src/mcp.ts new file mode 100644 index 0000000000..436fdc7e95 --- /dev/null +++ b/strands-ts/src/mcp.ts @@ -0,0 +1,484 @@ +import { Client } from '@modelcontextprotocol/sdk/client/index.js' +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' +import { ClientCredentialsProvider } from '@modelcontextprotocol/sdk/client/auth-extensions.js' +import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js' +import { takeResult } from '@modelcontextprotocol/sdk/shared/responseMessage.js' +import { + ElicitRequestSchema, + LoggingMessageNotificationSchema, + type ServerCapabilities, + type Implementation, + type LoggingMessageNotificationParams, +} from '@modelcontextprotocol/sdk/types.js' +import { context, propagation, trace } from '@opentelemetry/api' +import type { JSONSchema, JSONValue } from './types/json.js' +import type { ElicitationCallback } from './types/elicitation.js' +import { McpTool } from './tools/mcp-tool.js' +import { logger } from './logging/index.js' +import { type McpServerConfig, resolveServerConfigs } from './mcp-config.js' + +/** + * Widened transport type that accepts MCP transport implementations without requiring explicit casts. + * + * Under `exactOptionalPropertyTypes`, `StreamableHTTPClientTransport` is not directly assignable + * to `Transport` because its `sessionId` getter returns `string | undefined`, while `Transport` + * declares `sessionId?: string` (absent or string, but not explicitly undefined). + * This type relaxes that constraint so users can pass any MCP transport without `as Transport`. + */ +export type McpTransport = Omit & { sessionId?: string | undefined } + +/** Temporary placeholder for RuntimeConfig */ +export interface RuntimeConfig { + applicationName?: string + applicationVersion?: string +} + +/** + * Configuration for MCP task-augmented tool execution. + * + * WARNING: MCP Tasks is an experimental feature in both the MCP specification and this SDK. + * The API may change without notice in future versions. + * + * When provided to McpClient, enables task-based tool invocation which supports + * long-running tools with progress tracking. Without this config, tools are + * called directly without task management. + */ +export interface TasksConfig { + /** + * Time-to-live in milliseconds for task polling. + * Defaults to 60000 (60 seconds). + */ + ttl?: number + + /** + * Maximum time in milliseconds to wait for task completion during polling. + * Defaults to 300000 (5 minutes). + */ + pollTimeout?: number +} + +/** Connection state of an MCP client. */ +export type McpConnectionState = 'disconnected' | 'connected' | 'failed' + +/** Options for MCP tool invocation. */ +export interface McpCallToolOptions { + /** AbortSignal to cancel the in-flight request. */ + signal?: AbortSignal +} + +/** OAuth client credentials for machine-to-machine authentication. */ +export interface McpClientCredentials { + clientId: string + clientSecret: string + /** OAuth scopes to request. Joined with spaces before sending to the token endpoint. */ + scopes?: string[] +} + +/** Behavioral options shared by all MCP client configurations. */ +export interface McpClientOptions extends RuntimeConfig { + /** Disable OpenTelemetry MCP instrumentation. */ + disableMcpInstrumentation?: boolean + + /** + * Configuration for task-augmented tool execution (experimental). + * When provided (even as empty object), enables MCP task-based tool invocation. + * When undefined, tools are called directly without task management. + */ + tasksConfig?: TasksConfig + + /** + * Callback to handle server-initiated elicitation requests. + * When provided, the client advertises elicitation support (form + url modes) + * and routes incoming elicitation requests to this callback. + */ + elicitationCallback?: ElicitationCallback + + /** When true, connection failures are logged as warnings instead of throwing. */ + continueOnError?: boolean + + /** Called when the server emits a log message. Defaults to routing through the Strands logger. */ + logHandler?: (params: LoggingMessageNotificationParams) => void +} + +/** Arguments for configuring an MCP Client. */ +export type McpClientConfig = McpClientOptions & { + /** Pre-constructed transport. Mutually exclusive with `url`. */ + transport?: McpTransport + + /** Server URL. When provided, a StreamableHTTP transport is constructed automatically. */ + url?: string | URL + + /** Client credentials for OAuth machine-to-machine auth. Requires `url`. */ + auth?: McpClientCredentials + + /** Custom OAuth provider for advanced auth flows. Requires `url`. Mutually exclusive with `auth`. */ + authProvider?: OAuthClientProvider + + /** Custom headers to include on every request to the server. Requires `url`. */ + headers?: Record +} + +/** MCP Client for interacting with Model Context Protocol servers. */ +export class McpClient { + /** Default TTL for task polling in milliseconds (60 seconds). */ + public static readonly DEFAULT_TTL = 60000 + + /** Default poll timeout for task completion in milliseconds (5 minutes). */ + public static readonly DEFAULT_POLL_TIMEOUT = 300000 + + /** + * Parses an MCP servers config (file path or object) and returns McpClient instances. + * + * @param config - A file path to a JSON config, or a flat server map object. + * @param defaults - Options applied to all clients unless overridden per-server. + * @returns An array of McpClient instances ready to be passed to an Agent. + */ + public static async loadServers( + config: string | Record, + defaults?: McpClientOptions + ): Promise { + return (await resolveServerConfigs(config, defaults)).map((c) => new McpClient(c)) + } + + private _clientName: string + private _clientVersion: string + private _transport: Transport + private _state: McpConnectionState + private _client: Client + private _continueOnError: boolean + private _logHandler: (params: LoggingMessageNotificationParams) => void + private _disableMcpInstrumentation: boolean + private _tasksConfig: TasksConfig | undefined + private _elicitationCallback: ElicitationCallback | undefined + private _registeredToolNames = new Set() + private _onToolsChanged: ((oldTools: string[], newTools: McpTool[]) => void) | undefined + private _refreshingTools = false + private _pendingRefresh = false + + constructor(args: McpClientConfig) { + this._clientName = args.applicationName || 'strands-agents-ts-sdk' + this._clientVersion = args.applicationVersion || '0.0.1' + this._transport = McpClient._resolveTransport(args) + this._state = 'disconnected' + this._continueOnError = args.continueOnError ?? false + this._logHandler = args.logHandler ?? defaultLogHandler + this._tasksConfig = args.tasksConfig + this._elicitationCallback = args.elicitationCallback + this._client = new Client( + { + name: this._clientName, + version: this._clientVersion, + }, + { + ...(this._elicitationCallback ? { capabilities: { elicitation: { form: {}, url: {} } } } : undefined), + listChanged: { + tools: { + autoRefresh: false, + debounceMs: 300, + onChanged: (): void => { + this._handleToolsChanged() + }, + }, + }, + } + ) + + this._client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { + this._logHandler(notification.params) + }) + + this._disableMcpInstrumentation = args.disableMcpInstrumentation ?? false + } + + private static _resolveTransport(args: McpClientConfig): Transport { + if (args.transport && args.url) { + throw new Error('McpClientConfig: provide either "transport" or "url", not both') + } + if (!args.transport && !args.url) { + throw new Error('McpClientConfig: either "transport" or "url" must be provided') + } + if (args.transport) { + if (args.auth || args.authProvider || args.headers) { + throw new Error( + 'McpClientConfig: "auth", "authProvider", and "headers" require "url" (not compatible with "transport")' + ) + } + return args.transport as Transport + } + if (args.auth && args.authProvider) { + throw new Error('McpClientConfig: provide either "auth" or "authProvider", not both') + } + + const authProvider = args.auth + ? new ClientCredentialsProvider({ + clientId: args.auth.clientId, + clientSecret: args.auth.clientSecret, + ...(args.auth.scopes && { scope: args.auth.scopes.join(' ') }), + }) + : args.authProvider + + const url = args.url instanceof URL ? args.url : new URL(args.url!) + return new StreamableHTTPClientTransport(url, { + ...(authProvider && { authProvider }), + ...(args.headers && { requestInit: { headers: args.headers } }), + }) as Transport + } + + get client(): Client { + return this._client + } + + get serverCapabilities(): ServerCapabilities | undefined { + return this._client.getServerCapabilities() + } + + get serverVersion(): Implementation | undefined { + return this._client.getServerVersion() + } + + get serverInstructions(): string | undefined { + return this._client.getInstructions() + } + + get connectionState(): McpConnectionState { + return this._state + } + + get clientName(): string { + return this._clientName + } + + get continueOnError(): boolean { + return this._continueOnError + } + + /** + * Connects the MCP client to the server. + * + * Called lazily before any operation that requires a connection. When `continueOnError` is true, + * connection failures are swallowed and the client enters a `'failed'` state — subsequent + * calls are no-ops until `connect(true)` is called explicitly to retry. + * + * @param reconnect - When true, forces a reconnect even if already connected or failed. + * @returns A promise that resolves when the connection is established. + */ + public async connect(reconnect: boolean = false): Promise { + if (this._state !== 'disconnected' && !reconnect) return + + if (this._state === 'connected' && reconnect) { + await this._client.close() + this._state = 'disconnected' + } + + if (this._elicitationCallback) { + const callback = this._elicitationCallback + this._client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { + return await callback(extra, request.params) + }) + } + + try { + await this._client.connect(this._transport) + this._state = 'connected' + } catch (error) { + if (!this._continueOnError) throw error + this._state = 'failed' + logger.warn( + `client=<${this._clientName}>, error=<${error}> | MCP server failed to connect, continuing (continueOnError)` + ) + } + } + + /** + * Disconnects the MCP client from the server and cleans up resources. + * + * @returns A promise that resolves when the disconnection is complete. + */ + public async disconnect(): Promise { + // Must be done sequentially + await this._client.close() + await this._transport.close() + this._state = 'disconnected' + } + + /** + * Enables the `await using` pattern for automatic resource cleanup. + * Delegates to {@link McpClient.disconnect}. + */ + async [Symbol.asyncDispose](): Promise { + await this.disconnect() + } + + /** + * Lists the tools available on the server and returns them as executable McpTool instances. + * + * @returns A promise that resolves with an array of McpTool instances. + */ + public async listTools(): Promise { + await this.connect() + if (this._state === 'failed') return [] + + const tools: McpTool[] = [] + let cursor: string | undefined + + do { + const result = await this._client.listTools(cursor ? { cursor } : undefined) + + tools.push( + ...result.tools.map( + (toolSpec) => + new McpTool({ + name: toolSpec.name, + description: toolSpec.description || `Tool which performs ${toolSpec.name}`, + inputSchema: toolSpec.inputSchema as JSONSchema, + client: this, + }) + ) + ) + + cursor = result.nextCursor + } while (cursor) + + this._registeredToolNames = new Set(tools.map((t) => t.name)) + + return tools + } + + /** + * Sets a callback invoked when the MCP server's tool list changes at runtime. + * + * @param callback - Handler receiving the previous tool names and the refreshed tool instances, + * or undefined to remove the callback. + */ + set onToolsChanged(callback: ((oldTools: string[], newTools: McpTool[]) => void) | undefined) { + this._onToolsChanged = callback + } + + private async _handleToolsChanged(): Promise { + if (this._refreshingTools) { + this._pendingRefresh = true + return + } + this._refreshingTools = true + try { + do { + this._pendingRefresh = false + const oldTools = [...this._registeredToolNames] + const newTools = await this.listTools() + this._onToolsChanged?.(oldTools, newTools) + } while (this._pendingRefresh) + } catch (err) { + logger.warn( + `client=<${this._clientName}>, error=<${err}> | failed to refresh tools after toolsChanged notification` + ) + } finally { + this._refreshingTools = false + } + } + + /** + * Invoke a tool on the connected MCP server using an McpTool instance. + * + * When `tasksConfig` was provided to the client constructor, uses experimental + * task-based invocation which supports long-running tools with progress tracking. + * Otherwise, calls tools directly without task management. + * + * @param tool - The McpTool instance to invoke. + * @param args - The arguments to pass to the tool. + * @param options - Optional settings for the request. + * @returns A promise that resolves with the result of the tool invocation. + */ + public async callTool(tool: McpTool, args: JSONValue, options?: McpCallToolOptions): Promise { + await this.connect() + if (this._state === 'failed') throw new Error('MCP server failed to connect. Call connect(true) to retry.') + + if (args === null || args === undefined) { + return await this.callTool(tool, {}, options) + } + + if (typeof args !== 'object' || Array.isArray(args)) { + throw new Error( + `MCP Protocol Error: Tool arguments must be a JSON Object (named parameters). Received: ${Array.isArray(args) ? 'Array' : typeof args}` + ) + } + + // Inject OpenTelemetry trace context into tool arguments for distributed tracing + const enhancedArgs = this._disableMcpInstrumentation ? args : injectTraceContext(args) + const toolArgs = enhancedArgs as Record + + // When tasksConfig is undefined, call tools directly without task management + if (this._tasksConfig === undefined) { + return (await this._client.callTool({ name: tool.name, arguments: toolArgs }, undefined, options)) as JSONValue + } + + // When tasksConfig is defined (even as empty object), use task-based invocation + // which supports long-running tools with progress tracking + const stream = this._client.experimental.tasks.callToolStream({ name: tool.name, arguments: toolArgs }, undefined, { + timeout: this._tasksConfig.ttl ?? McpClient.DEFAULT_TTL, + maxTotalTimeout: this._tasksConfig.pollTimeout ?? McpClient.DEFAULT_POLL_TIMEOUT, + resetTimeoutOnProgress: true, + ...options, + }) + + const result = await takeResult(stream) + return result as JSONValue + } +} + +function defaultLogHandler(params: LoggingMessageNotificationParams): void { + const { level, logger: serverLogger, data } = params + const message = `logger=<${serverLogger ?? 'mcp'}>, data=<${JSON.stringify(data)}> | MCP server log` + if (level === 'debug') { + logger.debug(message) + } else if (level === 'info' || level === 'notice') { + logger.info(message) + } else if (level === 'warning') { + logger.warn(message) + } else { + logger.error(message) + } +} + +/** + * Carrier object for OpenTelemetry context propagation. + */ +interface ContextCarrier { + [key: string]: string | string[] | undefined +} + +/** + * Injects OpenTelemetry trace context into MCP tool call arguments. + * Returns the args with a `_meta` field containing W3C traceparent headers. + * If no active span exists or injection fails, returns the original args unchanged. + * + * @param args - The tool call arguments (must be a non-null object) + * @returns The args with trace context injected, or the original args on failure + */ +function injectTraceContext(args: JSONValue): JSONValue { + try { + const currentContext = context.active() + const currentSpan = trace.getSpan(currentContext) + + if (!currentSpan || !currentSpan.spanContext().traceId) { + return args + } + + const carrier: ContextCarrier = {} + propagation.inject(currentContext, carrier) + + const existingMeta = (args as Record)._meta + const mergedMeta = + existingMeta && typeof existingMeta === 'object' && !Array.isArray(existingMeta) + ? { ...existingMeta, ...carrier } + : carrier + + return { + ...(args as Record), + _meta: mergedMeta as unknown as JSONValue, + } + } catch (error) { + logger.warn(`error=<${error}> | failed to inject trace context into mcp tool call args`) + return args + } +} diff --git a/strands-ts/src/mime.ts b/strands-ts/src/mime.ts new file mode 100644 index 0000000000..4a4961238d --- /dev/null +++ b/strands-ts/src/mime.ts @@ -0,0 +1,95 @@ +/** + * MIME type utilities for media format detection and conversion. + * + * Provides bidirectional mapping between media formats and MIME types. + */ + +export const IMAGE_FORMATS = ['png', 'jpg', 'jpeg', 'gif', 'webp'] as const + +export type ImageFormat = (typeof IMAGE_FORMATS)[number] + +export type VideoFormat = 'mkv' | 'mov' | 'mp4' | 'webm' | 'flv' | 'mpeg' | 'mpg' | 'wmv' | '3gp' + +export type DocumentFormat = 'pdf' | 'csv' | 'doc' | 'docx' | 'xls' | 'xlsx' | 'html' | 'txt' | 'md' | 'json' | 'xml' + +export type MediaFormat = DocumentFormat | ImageFormat | VideoFormat + +const TO_MIME_TYPE: Record = { + // Images + png: 'image/png', + jpg: 'image/jpeg', + jpeg: 'image/jpeg', + gif: 'image/gif', + webp: 'image/webp', + // Videos + mkv: 'video/x-matroska', + mov: 'video/quicktime', + mp4: 'video/mp4', + webm: 'video/webm', + flv: 'video/x-flv', + mpeg: 'video/mpeg', + mpg: 'video/mpeg', + wmv: 'video/x-ms-wmv', + '3gp': 'video/3gpp', + // Documents + pdf: 'application/pdf', + csv: 'text/csv', + doc: 'application/msword', + docx: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + xls: 'application/vnd.ms-excel', + xlsx: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + html: 'text/html', + txt: 'text/plain', + md: 'text/markdown', + json: 'application/json', + xml: 'application/xml', +} + +const TO_MEDIA_FORMAT: Record = { + // Images + 'image/png': 'png', + 'image/jpeg': 'jpeg', + 'image/gif': 'gif', + 'image/webp': 'webp', + // Videos + 'video/x-matroska': 'mkv', + 'video/quicktime': 'mov', + 'video/mp4': 'mp4', + 'video/webm': 'webm', + 'video/x-flv': 'flv', + 'video/mpeg': 'mpeg', + 'video/x-ms-wmv': 'wmv', + 'video/3gpp': '3gp', + // Documents + 'application/pdf': 'pdf', + 'text/csv': 'csv', + 'application/msword': 'doc', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'docx', + 'application/vnd.ms-excel': 'xls', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'xlsx', + 'text/html': 'html', + 'text/plain': 'txt', + 'text/markdown': 'md', + 'application/json': 'json', + 'application/xml': 'xml', +} + +/** + * Convert a media format to its MIME type. + * + * @param format - Media format (e.g., 'png', 'pdf') + * @returns MIME type string or undefined if not a known format + */ +export function toMimeType(format: string): string | undefined { + return TO_MIME_TYPE[format.toLowerCase() as MediaFormat] +} + +/** + * Convert a MIME type to its canonical media format. + * + * @param mimeType - MIME type string (e.g., 'image/png', 'application/pdf') + * @returns Media format or undefined if not a known MIME type + */ +export function toMediaFormat(mimeType: string): MediaFormat | undefined { + return TO_MEDIA_FORMAT[mimeType.toLowerCase()] +} diff --git a/strands-ts/src/models/__tests__/anthropic.test.ts b/strands-ts/src/models/__tests__/anthropic.test.ts new file mode 100644 index 0000000000..f00ccf09bf --- /dev/null +++ b/strands-ts/src/models/__tests__/anthropic.test.ts @@ -0,0 +1,914 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import Anthropic from '@anthropic-ai/sdk' +import { isNode } from '../../__fixtures__/environment.js' +import { AnthropicModel } from '../anthropic.js' +import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' +import { collectIterator } from '../../__fixtures__/model-test-helpers.js' +import { + Message, + TextBlock, + CachePointBlock, + GuardContentBlock, + ToolResultBlock, + JsonBlock, +} from '../../types/messages.js' +import { ImageBlock, DocumentBlock, VideoBlock } from '../../types/media.js' +import { warnOnce } from '../../logging/warn-once.js' + +/** + * Helper to create a mock Anthropic client with streaming support + */ +function createMockClient(streamGenerator: () => AsyncGenerator): Anthropic { + return { + messages: { + stream: vi.fn(() => streamGenerator()), + countTokens: vi.fn(), + }, + } as unknown as Anthropic +} + +// Mock the Anthropic SDK +vi.mock('@anthropic-ai/sdk', () => { + const mockConstructor = vi.fn(function () { + return { + messages: { + stream: vi.fn(), + countTokens: vi.fn(), + }, + } + }) + return { + default: mockConstructor, + } +}) + +vi.mock('../../logging/warn-once.js', () => ({ + warnOnce: vi.fn(), +})) + +describe('AnthropicModel', () => { + beforeEach(() => { + vi.clearAllMocks() + if (isNode) { + vi.stubEnv('ANTHROPIC_API_KEY', 'sk-ant-test-env') + } + }) + + afterEach(() => { + vi.clearAllMocks() + if (isNode) { + vi.unstubAllEnvs() + } + }) + + describe('constructor', () => { + it('creates an instance with default configuration', () => { + const provider = new AnthropicModel({ apiKey: 'sk-ant-test' }) + const config = provider.getConfig() + expect(config.modelId).toBe('claude-sonnet-4-6') + expect(config.maxTokens).toBe(64_000) + }) + + it('uses provided model ID', () => { + const customModelId = 'claude-3-opus-20240229' + const provider = new AnthropicModel({ modelId: customModelId, apiKey: 'sk-ant-test' }) + expect(provider.getConfig().modelId).toBe(customModelId) + }) + + it('uses API key from constructor parameter', () => { + const apiKey = 'sk-explicit' + new AnthropicModel({ apiKey }) + expect(Anthropic).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey, + }) + ) + }) + + if (isNode) { + it('uses API key from environment variable', () => { + vi.stubEnv('ANTHROPIC_API_KEY', 'sk-from-env') + new AnthropicModel() + expect(Anthropic).toHaveBeenCalled() + }) + + it('throws error when no API key is available', () => { + vi.stubEnv('ANTHROPIC_API_KEY', '') + expect(() => new AnthropicModel()).toThrow('Anthropic API key is required') + }) + } + + it('uses provided client instance', () => { + const mockClient = {} as Anthropic + const provider = new AnthropicModel({ client: mockClient }) + expect(Anthropic).not.toHaveBeenCalled() + expect(provider).toBeDefined() + }) + + it('warns when maxTokens is not explicitly set', () => { + new AnthropicModel({ apiKey: 'sk-ant-test' }) + expect(warnOnce).toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default maxTokens') + ) + }) + + it('does not warn when maxTokens is explicitly set', () => { + new AnthropicModel({ apiKey: 'sk-ant-test', maxTokens: 4096 }) + expect(warnOnce).not.toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default maxTokens') + ) + }) + + it('warns when modelId is not explicitly set', () => { + new AnthropicModel({ apiKey: 'sk-ant-test' }) + expect(warnOnce).toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default modelId') + ) + }) + + it('does not warn when modelId is explicitly set', () => { + new AnthropicModel({ apiKey: 'sk-ant-test', modelId: 'claude-3-opus-20240229' }) + expect(warnOnce).not.toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default modelId') + ) + }) + + it('auto-populates contextWindowLimit from model ID lookup', () => { + const provider = new AnthropicModel({ apiKey: 'sk-test', modelId: 'claude-sonnet-4-20250514' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'claude-sonnet-4-20250514', + maxTokens: 64_000, + contextWindowLimit: 1_000_000, + }) + }) + + it('auto-populates contextWindowLimit for default model ID', () => { + const provider = new AnthropicModel({ apiKey: 'sk-test' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'claude-sonnet-4-6', + maxTokens: 64_000, + contextWindowLimit: 1_000_000, + }) + }) + + it('does not override explicit contextWindowLimit', () => { + const provider = new AnthropicModel({ + apiKey: 'sk-test', + modelId: 'claude-sonnet-4-20250514', + contextWindowLimit: 100_000, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'claude-sonnet-4-20250514', + maxTokens: 64_000, + contextWindowLimit: 100_000, + }) + }) + + it('leaves contextWindowLimit undefined for unknown model IDs', () => { + const provider = new AnthropicModel({ apiKey: 'sk-test', modelId: 'unknown-model' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'unknown-model', + maxTokens: 64_000, + }) + }) + }) + + describe('updateConfig', () => { + it('merges new config with existing config', () => { + const provider = new AnthropicModel({ apiKey: 'sk-test', temperature: 0.5 }) + provider.updateConfig({ temperature: 0.8, maxTokens: 8192 }) + expect(provider.getConfig()).toMatchObject({ + temperature: 0.8, + maxTokens: 8192, + }) + }) + + it('re-resolves contextWindowLimit when modelId changes and it was auto-resolved', () => { + const provider = new AnthropicModel({ apiKey: 'sk-test' }) + expect(provider.getConfig().contextWindowLimit).toBe(1_000_000) // claude-sonnet-4-6 default + + provider.updateConfig({ modelId: 'claude-sonnet-4-20250514' }) + expect(provider.getConfig().contextWindowLimit).toBe(1_000_000) // claude-sonnet-4-20250514 value + }) + + it('preserves explicit contextWindowLimit when modelId changes', () => { + const provider = new AnthropicModel({ apiKey: 'sk-test', contextWindowLimit: 50_000 }) + expect(provider.getConfig().contextWindowLimit).toBe(50_000) + + provider.updateConfig({ modelId: 'claude-sonnet-4-20250514' }) + expect(provider.getConfig().contextWindowLimit).toBe(50_000) // preserved + }) + }) + + describe('stream event handling', () => { + it('yields correct event sequence for simple text response', async () => { + const mockClient = createMockClient(async function* () { + yield { type: 'message_start', message: { role: 'assistant', usage: { input_tokens: 10 } } } + yield { type: 'content_block_start', index: 0, content_block: { type: 'text', text: '' } } + yield { type: 'content_block_delta', index: 0, delta: { type: 'text_delta', text: 'Hello' } } + yield { type: 'content_block_stop', index: 0 } + yield { type: 'message_delta', delta: { stop_reason: 'end_turn' }, usage: { output_tokens: 5 } } + yield { type: 'message_stop' } + }) + + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + expect(events).toHaveLength(6) + expect(events[0]).toEqual({ type: 'modelMessageStartEvent', role: 'assistant' }) + expect(events[1]).toEqual({ type: 'modelContentBlockStartEvent' }) + expect(events[2]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + }) + expect(events[3]).toEqual({ type: 'modelContentBlockStopEvent' }) + expect(events[4]).toEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }) + expect(events[5]).toEqual({ type: 'modelMessageStopEvent', stopReason: 'endTurn' }) + }) + + it('handles tool use events', async () => { + const mockClient = createMockClient(async function* () { + yield { type: 'message_start', message: { role: 'assistant', usage: { input_tokens: 10 } } } + yield { + type: 'content_block_start', + index: 0, + content_block: { type: 'tool_use', id: 'tool_1', name: 'calc' }, + } + yield { type: 'content_block_delta', index: 0, delta: { type: 'input_json_delta', partial_json: '{"a"' } } + yield { type: 'content_block_delta', index: 0, delta: { type: 'input_json_delta', partial_json: ':1}' } } + yield { type: 'content_block_stop', index: 0 } + yield { type: 'message_delta', delta: { stop_reason: 'tool_use' }, usage: { output_tokens: 10 } } + yield { type: 'message_stop' } + }) + + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + expect(events).toContainEqual({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'calc', toolUseId: 'tool_1' }, + }) + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"a"' }, + }) + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: ':1}' }, + }) + expect(events).toContainEqual({ type: 'modelMessageStopEvent', stopReason: 'toolUse' }) + }) + + it.each([ + ['pause_turn', 'pauseTurn'], + ['refusal', 'refusal'], + ])('maps anthropic stop reason "%s" to "%s"', async (anthropicReason, expected) => { + const mockClient = createMockClient(async function* () { + yield { type: 'message_start', message: { role: 'assistant', usage: { input_tokens: 1 } } } + yield { type: 'message_delta', delta: { stop_reason: anthropicReason }, usage: { output_tokens: 1 } } + yield { type: 'message_stop' } + }) + + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + expect(events).toContainEqual({ type: 'modelMessageStopEvent', stopReason: expected }) + }) + + it('handles thinking/reasoning events', async () => { + const mockClient = createMockClient(async function* () { + yield { type: 'message_start', message: { role: 'assistant', usage: { input_tokens: 10 } } } + // Thinking block + yield { type: 'content_block_start', index: 0, content_block: { type: 'thinking', thinking: '' } } + yield { type: 'content_block_delta', index: 0, delta: { type: 'thinking_delta', thinking: 'Hmm...' } } + yield { type: 'content_block_delta', index: 0, delta: { type: 'signature_delta', signature: 'sig_123' } } + yield { type: 'content_block_stop', index: 0 } + // Text block + yield { type: 'content_block_start', index: 1, content_block: { type: 'text', text: '' } } + yield { type: 'content_block_delta', index: 1, delta: { type: 'text_delta', text: 'Answer' } } + yield { type: 'content_block_stop', index: 1 } + + yield { type: 'message_delta', delta: { stop_reason: 'end_turn' }, usage: { output_tokens: 20 } } + yield { type: 'message_stop' } + }) + + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + // Check for thinking deltas + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: 'Hmm...' }, + }) + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', signature: 'sig_123' }, + }) + }) + + it('handles redacted thinking events', async () => { + const mockClient = createMockClient(async function* () { + yield { type: 'message_start', message: { role: 'assistant', usage: { input_tokens: 10 } } } + yield { + type: 'content_block_start', + index: 0, + content_block: { type: 'redacted_thinking', data: 'data' }, + } + yield { type: 'content_block_stop', index: 0 } + yield { type: 'message_delta', delta: { stop_reason: 'end_turn' }, usage: { output_tokens: 5 } } + yield { type: 'message_stop' } + }) + + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', redactedContent: 'data' }, + }) + }) + + it('handles text payload directly in content_block_start (optimization)', async () => { + const mockClient = createMockClient(async function* () { + yield { type: 'message_start', message: { role: 'assistant', usage: { input_tokens: 10 } } } + yield { type: 'content_block_start', index: 0, content_block: { type: 'text', text: 'Full text' } } + yield { type: 'content_block_stop', index: 0 } + yield { type: 'message_delta', delta: { stop_reason: 'end_turn' }, usage: { output_tokens: 5 } } + yield { type: 'message_stop' } + }) + + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Full text' }, + }) + }) + + it('handles error during stream', async () => { + const mockClient = createMockClient(async function* () { + yield { type: 'ping' } // Satisfy linter require-yield + throw new Error('API Error') + }) + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(collectIterator(provider.stream(messages))).rejects.toThrow('API Error') + }) + + it.each([ + 'PROMPT IS TOO LONG: request exceeds context window', + 'max_tokens exceeded', + 'input too long', + 'input is too long', + 'input length exceeds context window', + 'input and output tokens exceed your context limit', + ])('maps context overflow error "%s" to ContextWindowOverflowError', async (message) => { + const mockClient = createMockClient(async function* () { + yield { type: 'ping' } // Satisfy linter require-yield + throw new Error(message) + }) + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ContextWindowOverflowError) + }) + + it('maps HTTP 429 error to ModelThrottledError', async () => { + const rateLimitError = Object.assign(new Error('Rate limit exceeded'), { status: 429 }) + // eslint-disable-next-line require-yield + const mockClient = createMockClient(async function* () { + throw rateLimitError + }) + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError) + await expect(collectIterator(provider.stream(messages))).rejects.toThrow('Rate limit exceeded') + }) + }) + + describe('request formatting', () => { + // Helper to capture request arguments + const setupCapture = () => { + const captured: { request: any; options: any } = { request: null, options: null } + const mockClient = { + messages: { + stream: vi.fn((req, opts) => { + captured.request = req + captured.options = opts + return (async function* () {})() + }), + }, + } as any + return { captured, mockClient } + } + + it('formats basic request correctly', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ + modelId: 'claude-3-opus', + maxTokens: 1000, + temperature: 0.7, + client: mockClient, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await collectIterator(provider.stream(messages)) + + expect(captured.request).toEqual({ + model: 'claude-3-opus', + max_tokens: 1000, + temperature: 0.7, + messages: [{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }], + stream: true, + }) + }) + + it('formats tools correctly', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + const toolSpecs = [ + { + name: 'calc', + description: 'calculate', + inputSchema: { type: 'object' as const, properties: {} }, + }, + ] + + await collectIterator(provider.stream(messages, { toolSpecs, toolChoice: { auto: {} } })) + + expect(captured.request.tools).toHaveLength(1) + expect(captured.request.tools[0]).toEqual({ + name: 'calc', + description: 'calculate', + input_schema: { type: 'object', properties: {} }, + }) + expect(captured.request.tool_choice).toEqual({ type: 'auto' }) + }) + + describe('Prompt Caching (Lookahead logic)', () => { + it('attaches cache control to message content block followed by cache point', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new TextBlock('Cached content'), + // Use 'default' here; provider converts it to 'ephemeral' for Anthropic + new CachePointBlock({ cacheType: 'default' }), + new TextBlock('Non-cached content'), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const content = captured.request.messages[0].content + expect(content).toHaveLength(2) // 3 blocks reduced to 2 (cache point merged) + expect(content[0]).toEqual({ + type: 'text', + text: 'Cached content', + cache_control: { type: 'ephemeral' }, + }) + expect(content[1]).toEqual({ + type: 'text', + text: 'Non-cached content', + }) + }) + + it('formats system prompt string without cache', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await collectIterator(provider.stream(messages, { systemPrompt: 'System instruction' })) + + expect(captured.request.system).toBe('System instruction') + }) + + it('formats system prompt array with cache points', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + const systemPrompt = [ + new TextBlock('Heavy context'), + new CachePointBlock({ cacheType: 'default' }), + new TextBlock('Light context'), + ] + + await collectIterator(provider.stream(messages, { systemPrompt })) + + expect(Array.isArray(captured.request.system)).toBe(true) + const system = captured.request.system + expect(system).toHaveLength(2) + expect(system[0]).toEqual({ + type: 'text', + text: 'Heavy context', + cache_control: { type: 'ephemeral' }, + }) + expect(system[1]).toEqual({ + type: 'text', + text: 'Light context', + }) + }) + }) + + describe('Media blocks', () => { + it('formats images correctly', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) // "Hello" + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ + format: 'png', + source: { bytes: imageBytes }, + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const content = captured.request.messages[0].content[0] + expect(content.type).toBe('image') + expect(content.source.media_type).toBe('image/png') + // Base64 of "Hello" is "SGVsbG8=" + expect(content.source.data).toBe('SGVsbG8=') + }) + + it('formats PDFs correctly', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const pdfBytes = new Uint8Array([1, 2, 3]) + const messages = [ + new Message({ + role: 'user', + content: [ + new DocumentBlock({ + name: 'doc.pdf', + format: 'pdf', + source: { bytes: pdfBytes }, + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const content = captured.request.messages[0].content[0] + expect(content.type).toBe('document') + expect(content.source.media_type).toBe('application/pdf') + expect(content.title).toBe('doc.pdf') + }) + + it('logs warning for unsupported GuardContentBlock in user message', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) // Spy on console.warn (via logger) + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new GuardContentBlock({ + text: { text: 'guard', qualifiers: ['query'] }, + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + // Should result in empty content if blocked + expect(captured.request.messages[0].content).toHaveLength(0) + warnSpy.mockRestore() + }) + }) + + describe('Tool Results', () => { + it('formats simple text tool result', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 't1', + status: 'success', + content: [new TextBlock('42')], + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const content = captured.request.messages[0].content[0] + expect(content.type).toBe('tool_result') + expect(content.tool_use_id).toBe('t1') + expect(content.content).toBe('42') // Simplified to string + expect(content.is_error).toBe(false) + }) + + it('formats mixed tool result (json/image)', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 't1', + status: 'error', + content: [new JsonBlock({ json: { error: 'failed' } }), new TextBlock('Details here')], + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const content = captured.request.messages[0].content[0] + expect(content.type).toBe('tool_result') + expect(content.is_error).toBe(true) + expect(Array.isArray(content.content)).toBe(true) + // JSON is stringified in Anthropic tool result content + expect(content.content[0]).toEqual({ type: 'text', text: '{"error":"failed"}' }) + expect(content.content[1]).toEqual({ type: 'text', text: 'Details here' }) + }) + + it('formats image block inside tool result via recursive formatting', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 't1', + status: 'success', + content: [ + new TextBlock('Here is the screenshot'), + new ImageBlock({ format: 'png', source: { bytes: imageBytes } }), + ], + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const content = captured.request.messages[0].content[0] + expect(content.type).toBe('tool_result') + expect(Array.isArray(content.content)).toBe(true) + expect(content.content[0]).toEqual({ type: 'text', text: 'Here is the screenshot' }) + expect(content.content[1]).toEqual({ + type: 'image', + source: { type: 'base64', media_type: 'image/png', data: 'SGVsbG8=' }, + }) + }) + + it('formats document block inside tool result as text for text formats', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 't1', + status: 'success', + content: [new DocumentBlock({ name: 'data.json', format: 'json', source: { text: '{"key":"val"}' } })], + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const content = captured.request.messages[0].content[0] + expect(content.type).toBe('tool_result') + // Single text item collapses to string + expect(content.content).toBe('{"key":"val"}') + }) + + it('skips video block inside tool result with warning', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 't1', + status: 'success', + content: [ + new TextBlock('result'), + new VideoBlock({ format: 'mp4', source: { bytes: new Uint8Array([1]) } }), + ], + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const content = captured.request.messages[0].content[0] + expect(content.type).toBe('tool_result') + // Video is filtered out, single text collapses to string + expect(content.content).toBe('result') + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + }) + + describe('Beta headers', () => { + it('does not pass per-request options when betas is unset', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await collectIterator(provider.stream(messages)) + + expect(captured.options).toBeUndefined() + }) + + it('forwards configured betas as a per-request anthropic-beta header', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ + client: mockClient, + betas: ['interleaved-thinking-2025-05-14', 'mcp-client-2025-11-20'], + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await collectIterator(provider.stream(messages)) + + expect(captured.options).toEqual({ + headers: { 'anthropic-beta': 'interleaved-thinking-2025-05-14,mcp-client-2025-11-20' }, + }) + }) + + it('reflects updateConfig({ betas }) on the next request', async () => { + const { captured, mockClient } = setupCapture() + const provider = new AnthropicModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await collectIterator(provider.stream(messages)) + expect(captured.options).toBeUndefined() + + provider.updateConfig({ betas: ['interleaved-thinking-2025-05-14'] }) + await collectIterator(provider.stream(messages)) + + expect(captured.options).toEqual({ + headers: { 'anthropic-beta': 'interleaved-thinking-2025-05-14' }, + }) + }) + }) + }) + + describe('countTokens', () => { + const messages: Message[] = [new Message({ role: 'user', content: [new TextBlock('hello')] })] + const toolSpecs = [ + { name: 'test_tool', description: 'A test tool', inputSchema: { type: 'object' as const, properties: {} } }, + ] + + function createCountTokensClient(mockCountTokens: ReturnType): Anthropic { + return { + messages: { + stream: vi.fn(), + countTokens: mockCountTokens, + }, + } as unknown as Anthropic + } + + it('should use heuristic by default when useNativeTokenCount is not set', async () => { + const mockCountTokens = vi.fn() + const client = createCountTokensClient(mockCountTokens) + const model = new AnthropicModel({ client, modelId: 'claude-sonnet-4-6' }) + + const result = await model.countTokens(messages) + + expect(mockCountTokens).not.toHaveBeenCalled() + expect(result).toBe(2) // heuristic: Math.ceil('hello'.length / 4) + }) + + it('should return native token count on success', async () => { + const mockCountTokens = vi.fn(async () => ({ input_tokens: 42 })) + const client = createCountTokensClient(mockCountTokens) + const model = new AnthropicModel({ client, modelId: 'claude-sonnet-4-6', useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(result).toBe(42) + expect(mockCountTokens).toHaveBeenCalledOnce() + }) + + it('should include system prompt in request', async () => { + const mockCountTokens = vi.fn(async () => ({ input_tokens: 55 })) + const client = createCountTokensClient(mockCountTokens) + const model = new AnthropicModel({ client, modelId: 'claude-sonnet-4-6', useNativeTokenCount: true }) + + const result = await model.countTokens(messages, { systemPrompt: 'Be helpful.' }) + + expect(result).toBe(55) + expect(mockCountTokens).toHaveBeenCalledWith({ + model: 'claude-sonnet-4-6', + messages: [{ role: 'user', content: [{ type: 'text', text: 'hello' }] }], + system: 'Be helpful.', + }) + }) + + it('should include tool specs in request', async () => { + const mockCountTokens = vi.fn(async () => ({ input_tokens: 100 })) + const client = createCountTokensClient(mockCountTokens) + const model = new AnthropicModel({ client, modelId: 'claude-sonnet-4-6', useNativeTokenCount: true }) + + const result = await model.countTokens(messages, { toolSpecs }) + + expect(result).toBe(100) + expect(mockCountTokens).toHaveBeenCalledWith({ + model: 'claude-sonnet-4-6', + messages: [{ role: 'user', content: [{ type: 'text', text: 'hello' }] }], + tools: [{ name: 'test_tool', description: 'A test tool', input_schema: { type: 'object', properties: {} } }], + }) + }) + + it('should strip max_tokens from request', async () => { + const mockCountTokens = vi.fn(async () => ({ input_tokens: 10 })) + const client = createCountTokensClient(mockCountTokens) + const model = new AnthropicModel({ client, modelId: 'claude-sonnet-4-6', useNativeTokenCount: true }) + + await model.countTokens(messages) + + expect(mockCountTokens).toHaveBeenCalledWith({ + model: 'claude-sonnet-4-6', + messages: [{ role: 'user', content: [{ type: 'text', text: 'hello' }] }], + }) + }) + + it('should fall back to estimation on API error', async () => { + const mockCountTokens = vi.fn(async () => { + throw new Error('Unsupported') + }) + const client = createCountTokensClient(mockCountTokens) + const model = new AnthropicModel({ client, modelId: 'claude-sonnet-4-6', useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(typeof result).toBe('number') + expect(result).toBeGreaterThanOrEqual(0) + }) + + it('should fall back to estimation on generic exception', async () => { + const mockCountTokens = vi.fn(async () => { + throw new Error('Connection failed') + }) + const client = createCountTokensClient(mockCountTokens) + const model = new AnthropicModel({ client, modelId: 'claude-sonnet-4-6', useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(typeof result).toBe('number') + expect(result).toBeGreaterThanOrEqual(0) + }) + + it('should skip native API and use heuristic when useNativeTokenCount is false', async () => { + const mockCountTokens = vi.fn() + const client = createCountTokensClient(mockCountTokens) + const model = new AnthropicModel({ client, modelId: 'claude-sonnet-4-6', useNativeTokenCount: false }) + + const result = await model.countTokens(messages) + + expect(mockCountTokens).not.toHaveBeenCalled() + expect(result).toBe(2) // heuristic: Math.ceil('hello'.length / 4) + }) + }) +}) diff --git a/strands-ts/src/models/__tests__/bedrock.test.ts b/strands-ts/src/models/__tests__/bedrock.test.ts new file mode 100644 index 0000000000..02e5313cdd --- /dev/null +++ b/strands-ts/src/models/__tests__/bedrock.test.ts @@ -0,0 +1,4518 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { + BedrockRuntimeClient, + ConverseStreamCommand, + CountTokensCommand, + ValidationException, +} from '@aws-sdk/client-bedrock-runtime' +import { isNode } from '../../__fixtures__/environment.js' +import { BedrockModel } from '../bedrock.js' +import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' +import { Message, ReasoningBlock, ToolUseBlock, ToolResultBlock, JsonBlock } from '../../types/messages.js' +import type { SystemContentBlock } from '../../types/messages.js' +import { TextBlock, GuardContentBlock, CachePointBlock } from '../../types/messages.js' +import { ImageBlock, VideoBlock, DocumentBlock } from '../../types/media.js' +import { CitationsBlock } from '../../types/citations.js' +import type { StreamOptions } from '../model.js' +import { collectIterator } from '../../__fixtures__/model-test-helpers.js' +import { NOOP_TOOL_SPEC } from '../../tools/noop-tool.js' +import { warnOnce } from '../../logging/warn-once.js' + +/** + * Helper function to mock BedrockRuntimeClient implementation with customizable config. + * @param options - Optional configuration for mock region, useFipsEndpoint, and send functions + */ +function mockBedrockClientImplementation(options?: { + region?: () => Promise + useFipsEndpoint?: () => Promise + send?: (...args: unknown[]) => Promise +}): void { + const mockSend = vi.fn( + options?.send ?? + (async () => { + throw new Error('send() not mocked - specify send option if needed') + }) + ) + + vi.mocked(BedrockRuntimeClient).mockImplementation(function (...args: unknown[]) { + // Extract region from constructor args if provided + const clientConfig = (args[0] as { region?: string } | undefined) ?? {} + const configuredRegion = clientConfig.region + + const mockRegion = vi.fn( + options?.region ?? + (async () => { + // If region was explicitly configured in constructor, return it; otherwise return default + if (configuredRegion) return configuredRegion + return 'us-east-1' + }) + ) + const mockUseFipsEndpoint = vi.fn(options?.useFipsEndpoint ?? (async () => false)) + + return { + send: mockSend, + middlewareStack: { add: vi.fn() }, + config: { + region: mockRegion, + useFipsEndpoint: mockUseFipsEndpoint, + }, + } as never + } as never) +} + +/** + * Helper function to setup mock send with custom stream generator. + */ +function setupMockSend(streamGenerator: () => AsyncGenerator): void { + vi.clearAllMocks() + const mockSend = vi.fn( + async (): Promise<{ stream: AsyncIterable }> => ({ + stream: streamGenerator(), + }) + ) + mockBedrockClientImplementation({ send: mockSend }) +} + +// Mock the AWS SDK +vi.mock('@aws-sdk/client-bedrock-runtime', async (importOriginal) => { + const originalModule = await importOriginal() + + // Mock command classes that the code under test will instantiate + const ConverseStreamCommand = vi.fn() + const ConverseCommand = vi.fn() + + const mockSend = vi.fn(async (command: unknown) => { + // Check which constructor was used to create the command object + if (command instanceof ConverseStreamCommand) { + // Return a streaming response + return { + stream: (async function* (): AsyncGenerator { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + metrics: { latencyMs: 100 }, + }, + } + })(), + } + } + + if (command instanceof ConverseCommand) { + // Return a non-streaming (full) response for the non-streaming API + return { + output: { + message: { + role: 'assistant', + content: [{ text: 'Hello' }], + }, + }, + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + metrics: { latencyMs: 100 }, + } + } + + throw new Error('Unhandled command type in mock') + }) + + // Create a mock CountTokensCommand class + const CountTokensCommand = vi.fn() + + // Create a mock ValidationException class + class MockValidationException extends Error { + constructor(opts: { message: string; $metadata: Record }) { + super(opts.message) + this.name = 'ValidationException' + } + } + + return { + ...originalModule, + BedrockRuntimeClient: vi.fn(function () { + return { + send: mockSend, + middlewareStack: { add: vi.fn() }, + config: { + region: vi.fn(async () => 'us-east-1'), + useFipsEndpoint: vi.fn(async () => false), + }, + } + }), + ConverseStreamCommand, + ConverseCommand, + CountTokensCommand, + ValidationException: MockValidationException, + } +}) + +vi.mock('../../logging/warn-once.js', () => ({ + warnOnce: vi.fn(), +})) + +describe('BedrockModel', () => { + const BEDROCK_NOOP_TOOL_CONFIG = { + tools: [{ toolSpec: { ...NOOP_TOOL_SPEC, inputSchema: { json: NOOP_TOOL_SPEC.inputSchema } } }], + } + + beforeEach(() => { + vi.clearAllMocks() + // Reset mock to a working implementation to ensure test isolation + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } + }) + // Clean up AWS_REGION env var in Node.js only + if (isNode && process.env) { + delete process.env.AWS_REGION + } + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe('constructor', () => { + it('creates an instance with default configuration', () => { + const provider = new BedrockModel() + const config = provider.getConfig() + expect(config.modelId).toBeDefined() + }) + + it('warns when modelId is not explicitly set', () => { + new BedrockModel() + expect(warnOnce).toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default modelId') + ) + }) + + it('does not warn when modelId is explicitly set', () => { + new BedrockModel({ modelId: 'us.anthropic.claude-3-5-sonnet-20241022-v2:0' }) + expect(warnOnce).not.toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default modelId') + ) + }) + + it('uses provided model ID ', () => { + const customModelId = 'us.anthropic.claude-3-5-sonnet-20241022-v2:0' + const provider = new BedrockModel({ modelId: customModelId }) + expect(provider.getConfig()).toStrictEqual({ + modelId: customModelId, + contextWindowLimit: 200_000, + }) + }) + + it('uses provided region', () => { + const customRegion = 'eu-west-1' + new BedrockModel({ region: customRegion }) + expect(BedrockRuntimeClient).toHaveBeenCalledWith({ + region: customRegion, + customUserAgent: 'strands-agents-ts-sdk', + requestHandler: { requestTimeout: 120_000 }, + }) + }) + + it('extends custom user agent if provided', () => { + const customAgent = 'my-app/1.0' + new BedrockModel({ region: 'us-west-2', clientConfig: { customUserAgent: customAgent } }) + expect(BedrockRuntimeClient).toHaveBeenCalledWith({ + region: 'us-west-2', + customUserAgent: 'my-app/1.0 strands-agents-ts-sdk', + requestHandler: { requestTimeout: 120_000 }, + }) + }) + + it('passes custom endpoint to client', () => { + const endpoint = 'https://vpce-abc.bedrock-runtime.us-west-2.vpce.amazonaws.com' + const region = 'us-west-2' + new BedrockModel({ region, clientConfig: { endpoint } }) + expect(BedrockRuntimeClient).toHaveBeenCalledWith({ + region, + endpoint, + customUserAgent: 'strands-agents-ts-sdk', + requestHandler: { requestTimeout: 120_000 }, + }) + }) + + it('passes custom credentials to client', () => { + const credentials = { + accessKeyId: 'AKIAIOSFODNN7EXAMPLE', + secretAccessKey: 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', + } + const region = 'us-west-2' + new BedrockModel({ region, clientConfig: { credentials } }) + expect(BedrockRuntimeClient).toHaveBeenCalledWith({ + region, + credentials, + customUserAgent: 'strands-agents-ts-sdk', + requestHandler: { requestTimeout: 120_000 }, + }) + }) + + it('applies a default 120s request timeout', () => { + new BedrockModel({ region: 'us-west-2' }) + expect(BedrockRuntimeClient).toHaveBeenCalledWith( + expect.objectContaining({ requestHandler: { requestTimeout: 120_000 } }) + ) + }) + + it('lets the caller override requestTimeout', () => { + new BedrockModel({ region: 'us-west-2', clientConfig: { requestHandler: { requestTimeout: 5_000 } } }) + expect(BedrockRuntimeClient).toHaveBeenCalledWith( + expect.objectContaining({ requestHandler: { requestTimeout: 5_000 } }) + ) + }) + + it('merges the default timeout with other requestHandler options', () => { + new BedrockModel({ region: 'us-west-2', clientConfig: { requestHandler: { connectionTimeout: 1_000 } } }) + expect(BedrockRuntimeClient).toHaveBeenCalledWith( + expect.objectContaining({ requestHandler: { requestTimeout: 120_000, connectionTimeout: 1_000 } }) + ) + }) + + it('passes a user-provided handler instance through untouched', () => { + const handler = { handle: vi.fn(), updateHttpClientConfig: vi.fn(), httpHandlerConfigs: vi.fn() } + new BedrockModel({ region: 'us-west-2', clientConfig: { requestHandler: handler } }) + expect(BedrockRuntimeClient).toHaveBeenCalledWith(expect.objectContaining({ requestHandler: handler })) + }) + + it('adds api key middleware when apiKey is provided', () => { + const provider = new BedrockModel({ region: 'us-east-1', apiKey: 'br-test-key' }) + const mockAdd = provider['_client'].middlewareStack.add as ReturnType + expect(mockAdd).toHaveBeenCalledWith(expect.any(Function), { + step: 'finalizeRequest', + priority: 'low', + name: 'bedrockApiKeyMiddleware', + }) + }) + + it('does not add api key middleware when apiKey is not provided', () => { + const provider = new BedrockModel({ region: 'us-east-1' }) + const mockAdd = provider['_client'].middlewareStack.add as ReturnType + expect(mockAdd).not.toHaveBeenCalled() + }) + + it('api key middleware sets authorization header', async () => { + const provider = new BedrockModel({ region: 'us-east-1', apiKey: 'br-test-key' }) + const mockAdd = provider['_client'].middlewareStack.add as ReturnType + const middlewareFn = mockAdd.mock.calls[0]![0] as ( + next: (args: unknown) => Promise + ) => (args: unknown) => Promise + + const mockNext = vi.fn(async (args: unknown) => args) + const handler = middlewareFn(mockNext) + const args = { request: { headers: { authorization: 'AWS4-HMAC-SHA256 ...' } } } + await handler(args) + + expect(args.request.headers['authorization']).toBe('Bearer br-test-key') + expect(mockNext).toHaveBeenCalledWith(args) + }) + + it('does not include apiKey in model config', () => { + const provider = new BedrockModel({ region: 'us-east-1', apiKey: 'br-test-key', temperature: 0.5 }) + const config = provider.getConfig() + expect(config).toStrictEqual({ + modelId: 'global.anthropic.claude-sonnet-4-6', + temperature: 0.5, + contextWindowLimit: 1_000_000, + }) + }) + + it('includes contextWindowLimit in config when provided', () => { + const provider = new BedrockModel({ + modelId: 'anthropic.claude-sonnet-4-20250514-v1:0', + contextWindowLimit: 200_000, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'anthropic.claude-sonnet-4-20250514-v1:0', + contextWindowLimit: 200_000, + }) + }) + + it('auto-populates contextWindowLimit from model ID lookup', () => { + const provider = new BedrockModel({ modelId: 'anthropic.claude-sonnet-4-20250514-v1:0' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'anthropic.claude-sonnet-4-20250514-v1:0', + contextWindowLimit: 1_000_000, + }) + }) + + it('auto-populates contextWindowLimit for cross-region model IDs', () => { + const provider = new BedrockModel({ modelId: 'us.anthropic.claude-sonnet-4-6' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'us.anthropic.claude-sonnet-4-6', + contextWindowLimit: 1_000_000, + }) + }) + + it('auto-populates contextWindowLimit for default model ID', () => { + const provider = new BedrockModel() + expect(provider.getConfig()).toStrictEqual({ + modelId: 'global.anthropic.claude-sonnet-4-6', + contextWindowLimit: 1_000_000, + }) + }) + + it('does not override explicit contextWindowLimit', () => { + const provider = new BedrockModel({ + modelId: 'anthropic.claude-sonnet-4-20250514-v1:0', + contextWindowLimit: 100_000, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'anthropic.claude-sonnet-4-20250514-v1:0', + contextWindowLimit: 100_000, + }) + }) + + it('leaves contextWindowLimit undefined for unknown model IDs', () => { + const provider = new BedrockModel({ modelId: 'unknown.model-v1:0' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'unknown.model-v1:0', + }) + }) + }) + + describe('updateConfig', () => { + it('merges new config with existing config', () => { + const provider = new BedrockModel({ region: 'us-west-2', temperature: 0.5 }) + provider.updateConfig({ temperature: 0.8, maxTokens: 2048 }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'global.anthropic.claude-sonnet-4-6', + temperature: 0.8, + maxTokens: 2048, + contextWindowLimit: 1_000_000, + }) + }) + + it('preserves fields not included in the update', () => { + const provider = new BedrockModel({ + region: 'us-west-2', + modelId: 'custom-model', + temperature: 0.5, + maxTokens: 1024, + }) + provider.updateConfig({ temperature: 0.8 }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'custom-model', + temperature: 0.8, + maxTokens: 1024, + }) + }) + + it('re-resolves contextWindowLimit when modelId changes and it was auto-resolved', () => { + const provider = new BedrockModel({ region: 'us-west-2' }) + expect(provider.getConfig().contextWindowLimit).toBe(1_000_000) + + provider.updateConfig({ modelId: 'anthropic.claude-haiku-4-5-20251001-v1:0' }) + expect(provider.getConfig().contextWindowLimit).toBe(200_000) + }) + + it('clears contextWindowLimit when modelId changes to unknown model', () => { + const provider = new BedrockModel({ region: 'us-west-2' }) + expect(provider.getConfig().contextWindowLimit).toBe(1_000_000) + + provider.updateConfig({ modelId: 'my-custom-finetuned-model' }) + expect(provider.getConfig().contextWindowLimit).toBeUndefined() + }) + + it('preserves explicit contextWindowLimit when modelId changes', () => { + const provider = new BedrockModel({ region: 'us-west-2', contextWindowLimit: 50_000 }) + expect(provider.getConfig().contextWindowLimit).toBe(50_000) + + provider.updateConfig({ modelId: 'anthropic.claude-haiku-4-5-20251001-v1:0' }) + expect(provider.getConfig().contextWindowLimit).toBe(50_000) + }) + }) + + describe('getConfig', () => { + it('returns the current configuration', () => { + const provider = new BedrockModel({ + region: 'us-west-2', + modelId: 'test-model', + maxTokens: 1024, + temperature: 0.7, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'test-model', + maxTokens: 1024, + temperature: 0.7, + }) + }) + }) + + describe('format_message', async () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + it('formats the request to bedrock properly', async () => { + const provider = new BedrockModel({ + region: 'us-west-2', + modelId: 'anthropic.claude-test-model', + maxTokens: 1024, + temperature: 0.7, + topP: 0.9, + stopSequences: ['STOP'], + cacheConfig: { strategy: 'auto' }, + additionalResponseFieldPaths: ['Hello!'], + additionalRequestFields: ['World!'], + additionalArgs: { + MyExtraArg: 'ExtraArg', + }, + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const options: StreamOptions = { + systemPrompt: 'You are a helpful assistant', + toolSpecs: [ + { + name: 'calculator', + description: 'Perform calculations', + inputSchema: { type: 'object', properties: { expression: { type: 'string' } } }, + }, + ], + toolChoice: { auto: {} }, + } + + // Trigger the stream to make the request, but ignore the events for now + collectIterator(provider.stream(messages, options)) + + // Verify ConverseStreamCommand was called with properly formatted request + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + MyExtraArg: 'ExtraArg', + additionalModelRequestFields: ['World!'], + additionalModelResponseFieldPaths: ['Hello!'], + modelId: 'anthropic.claude-test-model', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }, { cachePoint: { type: 'default' } }], + }, + ], + system: [{ text: 'You are a helpful assistant' }], + toolConfig: { + toolChoice: { auto: {} }, + tools: [ + { + toolSpec: { + name: 'calculator', + description: 'Perform calculations', + inputSchema: { json: { type: 'object', properties: { expression: { type: 'string' } } } }, + }, + }, + { cachePoint: { type: 'default' } }, + ], + }, + inferenceConfig: { + maxTokens: 1024, + temperature: 0.7, + topP: 0.9, + stopSequences: ['STOP'], + }, + }) + }) + + it('formats tool use messages', async () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ + name: 'calculator', + toolUseId: 'tool-123', + input: { a: 5, b: 3 }, + }), + ], + }), + ] + + // Run the stream but ignore the output + collectIterator(provider.stream(messages)) + + // Verify ConverseStreamCommand was called with properly formatted request + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'assistant', + content: expect.arrayContaining([ + expect.objectContaining({ + toolUse: expect.objectContaining({ + name: 'calculator', + toolUseId: 'tool-123', + input: { a: 5, b: 3 }, + }), + }), + ]), + }), + ]), + }) + ) + }) + + it('formats tool result messages', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-123', + status: 'success', + content: [new TextBlock('Result: 8'), new JsonBlock({ json: { hello: 'world' } })], + }), + ], + }), + ] + + // Start the stream + collectIterator(provider.stream(messages)) + + // Verify ConverseStreamCommand was called with properly formatted request + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + messages: [ + { + content: [ + { + toolResult: { + content: [ + { + text: 'Result: 8', + }, + { + json: { + hello: 'world', + }, + }, + ], + status: 'success', + toolUseId: 'tool-123', + }, + }, + ], + role: 'user', + }, + ], + toolConfig: BEDROCK_NOOP_TOOL_CONFIG, + modelId: expect.any(String), + }) + }) + + it('injects noop tool config when messages have tool blocks but no toolSpecs', async () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'calc', toolUseId: 'id-1', input: { a: 1 } })], + }), + new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId: 'id-1', status: 'success', content: [new TextBlock('42')] })], + }), + new Message({ role: 'user', content: [new TextBlock('Summarize')] }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + toolConfig: BEDROCK_NOOP_TOOL_CONFIG, + }) + ) + }) + + it('does not inject noop tool config when messages have no tool blocks', async () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + collectIterator(provider.stream(messages)) + + const call = mockConverseStreamCommand.mock.calls[0]![0] as unknown as Record + expect(call.toolConfig).toBeUndefined() + }) + + it('does not inject noop tool config when toolSpecs are provided', async () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'calc', toolUseId: 'id-1', input: {} })], + }), + new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId: 'id-1', status: 'success', content: [new TextBlock('ok')] })], + }), + ] + + const options: StreamOptions = { + toolSpecs: [{ name: 'calc', description: 'Calculator', inputSchema: { type: 'object', properties: {} } }], + } + collectIterator(provider.stream(messages, options)) + + const call = mockConverseStreamCommand.mock.calls[0]![0] as unknown as Record + const toolConfig = call.toolConfig as { tools: Array<{ toolSpec?: { name: string } }> } + expect(toolConfig.tools[0]!.toolSpec!.name).toBe('calc') + expect(toolConfig.tools.length).toBe(1) + }) + + it('formats reasoning messages properly', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [ + new ReasoningBlock({ + text: 'Hello', + signature: 'World', + }), + new ReasoningBlock({ + redactedContent: new Uint8Array(1), + }), + ], + }), + ] + + // Start the stream but don't await it + collectIterator(provider.stream(messages)) + + // Verify ConverseStreamCommand was called with properly formatted request + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + messages: [ + { + role: 'user', + content: [ + { + reasoningContent: { + reasoningText: { + signature: 'World', + text: 'Hello', + }, + }, + }, + { + reasoningContent: { + redactedContent: new Uint8Array(1), + }, + }, + ], + }, + ], + modelId: expect.any(String), + }) + }) + + it('formats cache point blocks in messages', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [new TextBlock('Message with cache point'), new CachePointBlock({ cacheType: 'default' })], + }), + ] + + collectIterator(provider.stream(messages)) + + // Verify ConverseStreamCommand was called with properly formatted request + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + messages: [ + { + role: 'user', + content: [{ text: 'Message with cache point' }, { cachePoint: { type: 'default' } }], + }, + ], + modelId: expect.any(String), + }) + }) + + it('preserves ttl on user-supplied cache point blocks in messages', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [ + new TextBlock('Message with 1h cache point'), + new CachePointBlock({ cacheType: 'default', ttl: '1h' }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + messages: [ + { + role: 'user', + content: [{ text: 'Message with 1h cache point' }, { cachePoint: { type: 'default', ttl: '1h' } }], + }, + ], + modelId: expect.any(String), + }) + }) + + it('preserves ttl on cache point blocks in system prompt', async () => { + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: [ + new TextBlock('You are a helpful assistant'), + new CachePointBlock({ cacheType: 'default', ttl: '5m' }), + ], + } + + collectIterator(provider.stream(messages, options)) + + const call = mockConverseStreamCommand.mock.lastCall?.[0] + expect(call?.system).toStrictEqual([ + { text: 'You are a helpful assistant' }, + { cachePoint: { type: 'default', ttl: '5m' } }, + ]) + }) + + it('forwards arbitrary ttl strings without client-side validation (Bedrock validates server-side)', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [new TextBlock('Hello'), new CachePointBlock({ cacheType: 'default', ttl: '2h' })], + }), + ] + + collectIterator(provider.stream(messages)) + + const call = mockConverseStreamCommand.mock.lastCall?.[0] + const userMsg = call?.messages?.[0] + const lastBlock = userMsg?.content?.[userMsg.content.length - 1] + expect(lastBlock).toStrictEqual({ cachePoint: { type: 'default', ttl: '2h' } }) + }) + }) + + describe.each([ + { mode: 'streaming', stream: true }, + { mode: 'non-streaming', stream: false }, + ])('BedrockModel in $mode mode', ({ stream }) => { + it('yields and validates text events correctly', async () => { + const mockSend = vi.fn(async () => { + if (stream) { + return { + stream: (async function* (): AsyncGenerator { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { + metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, metrics: { latencyMs: 100 } }, + } + })(), + } + } else { + return { + output: { message: { role: 'assistant', content: [{ text: 'Hello' }] } }, + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + metrics: { latencyMs: 100 }, + } + } + }) + + mockBedrockClientImplementation({ send: mockSend }) + + const provider = new BedrockModel({ stream }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const events = await collectIterator(provider.stream(messages)) + + expect(events).toContainEqual({ role: 'assistant', type: 'modelMessageStartEvent' }) + expect(events).toContainEqual({ type: 'modelContentBlockStartEvent' }) + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + }) + expect(events).toContainEqual({ type: 'modelContentBlockStopEvent' }) + expect(events).toContainEqual({ type: 'modelMessageStopEvent', stopReason: 'endTurn' }) + expect(events).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + metrics: { latencyMs: 100 }, + }) + }) + + it('yields and validates toolUse events correctly', async () => { + const mockSend = vi.fn(async () => { + if (stream) { + return { + stream: (async function* (): AsyncGenerator { + yield { messageStart: { role: 'assistant' } } + yield { + contentBlockStart: { + start: { toolUse: { toolUseId: 'tool-use-123', name: 'get_weather' } }, + }, + } + yield { + contentBlockDelta: { + delta: { toolUse: { input: '{"location":"San Francisco"}' } }, + }, + } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'tool_use' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 25, totalTokens: 35 }, + metrics: { latencyMs: 120 }, + }, + } + })(), + } + } else { + return { + output: { + message: { + role: 'assistant', + content: [ + { toolUse: { toolUseId: 'tool-use-123', name: 'get_weather', input: { location: 'San Francisco' } } }, + ], + }, + }, + stopReason: 'tool_use', + usage: { inputTokens: 10, outputTokens: 25, totalTokens: 35 }, + metrics: { latencyMs: 120 }, + } + } + }) + mockBedrockClientImplementation({ send: mockSend }) + + const provider = new BedrockModel({ stream }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Weather?')] })] + const events = await collectIterator(provider.stream(messages)) + const startEvent = events.find((e) => e.type === 'modelContentBlockStartEvent') + const inputDeltaEvent = events.find( + (e) => e.type === 'modelContentBlockDeltaEvent' && e.delta.type === 'toolUseInputDelta' + ) + + expect(events).toContainEqual({ role: 'assistant', type: 'modelMessageStartEvent' }) + expect(startEvent).toStrictEqual({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'get_weather', toolUseId: 'tool-use-123' }, + }) + expect(inputDeltaEvent).toStrictEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"location":"San Francisco"}' }, + }) + expect(events).toContainEqual({ type: 'modelContentBlockStopEvent' }) + expect(events).toContainEqual({ stopReason: 'toolUse', type: 'modelMessageStopEvent' }) + expect(events).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 25, totalTokens: 35 }, + metrics: { latencyMs: 120 }, + }) + }) + + it('yields and validates reasoningText events correctly', async () => { + const mockSend = vi.fn(async () => { + if (stream) { + return { + stream: (async function* (): AsyncGenerator { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { + contentBlockDelta: { delta: { reasoningContent: { text: 'Thinking...' } } }, + } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { + metadata: { + usage: { inputTokens: 15, outputTokens: 30, totalTokens: 45 }, + metrics: { latencyMs: 150 }, + }, + } + })(), + } + } else { + return { + output: { + message: { + role: 'assistant', + content: [{ reasoningContent: { reasoningText: { text: 'Thinking...' } } }], + }, + }, + stopReason: 'end_turn', + usage: { inputTokens: 15, outputTokens: 30, totalTokens: 45 }, + metrics: { latencyMs: 150 }, + } + } + }) + mockBedrockClientImplementation({ send: mockSend }) + + const provider = new BedrockModel({ stream }) + const messages = [new Message({ role: 'user', content: [new TextBlock('A question.')] })] + const events = await collectIterator(provider.stream(messages)) + + expect(events).toContainEqual({ role: 'assistant', type: 'modelMessageStartEvent' }) + expect(events).toContainEqual({ type: 'modelContentBlockStartEvent' }) + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: 'Thinking...' }, + }) + expect(events).toContainEqual({ type: 'modelContentBlockStopEvent' }) + expect(events).toContainEqual({ stopReason: 'endTurn', type: 'modelMessageStopEvent' }) + expect(events).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 15, outputTokens: 30, totalTokens: 45 }, + metrics: { latencyMs: 150 }, + }) + }) + + it('yields and validates redactedContent events correctly', async () => { + const redactedBytes = new Uint8Array([1, 2, 3]) + + const mockSend = vi.fn(async () => { + if (stream) { + return { + stream: (async function* (): AsyncGenerator { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { + contentBlockDelta: { + delta: { reasoningContent: { redactedContent: redactedBytes } }, + }, + } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { + metadata: { usage: { inputTokens: 15, outputTokens: 5, totalTokens: 20 }, metrics: { latencyMs: 110 } }, + } + })(), + } + } else { + return { + output: { + message: { + role: 'assistant', + content: [{ reasoningContent: { redactedContent: redactedBytes } }], + }, + }, + stopReason: 'end_turn', + usage: { inputTokens: 15, outputTokens: 5, totalTokens: 20 }, + metrics: { latencyMs: 110 }, + } + } + }) + mockBedrockClientImplementation({ send: mockSend }) + + const provider = new BedrockModel({ stream }) + const messages = [new Message({ role: 'user', content: [new TextBlock('A sensitive question.')] })] + const events = await collectIterator(provider.stream(messages)) + + expect(events).toContainEqual({ role: 'assistant', type: 'modelMessageStartEvent' }) + expect(events).toContainEqual({ type: 'modelContentBlockStartEvent' }) + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', redactedContent: redactedBytes }, + }) + expect(events).toContainEqual({ type: 'modelContentBlockStopEvent' }) + expect(events).toContainEqual({ stopReason: 'endTurn', type: 'modelMessageStopEvent' }) + expect(events).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 15, outputTokens: 5, totalTokens: 20 }, + metrics: { latencyMs: 110 }, + }) + }) + + it('yields and validates citation events correctly', async () => { + // Bedrock streaming sends individual citation deltas with key 'citation' + const bedrockCitationDelta = { + location: { documentChar: { documentIndex: 0, start: 10, end: 50 } }, + sourceContent: [{ text: 'source text' }], + source: 'doc-0', + title: 'Test Doc', + } + + // Bedrock non-streaming wire format uses object-key discrimination + const bedrockCitationsData = { + citations: [bedrockCitationDelta], + content: [{ text: 'generated text' }], + } + + const mockSend = vi.fn(async () => { + if (stream) { + return { + stream: (async function* (): AsyncGenerator { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { + contentBlockDelta: { + delta: { citation: bedrockCitationDelta }, + }, + } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { + metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, metrics: { latencyMs: 100 } }, + } + })(), + } + } else { + return { + output: { + message: { + role: 'assistant', + content: [{ citationsContent: bedrockCitationsData }], + }, + }, + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + metrics: { latencyMs: 100 }, + } + } + }) + mockBedrockClientImplementation({ send: mockSend }) + + const provider = new BedrockModel({ stream }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Cite this.')] })] + const events = await collectIterator(provider.stream(messages)) + + // SDK events should use type-field discrimination + expect(events).toContainEqual({ role: 'assistant', type: 'modelMessageStartEvent' }) + expect(events).toContainEqual({ type: 'modelContentBlockStartEvent' }) + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'citationsDelta', + citations: [ + { + location: { type: 'documentChar', documentIndex: 0, start: 10, end: 50 }, + sourceContent: [{ text: 'source text' }], + source: 'doc-0', + title: 'Test Doc', + }, + ], + content: stream ? [] : [{ text: 'generated text' }], + }, + }) + expect(events).toContainEqual({ type: 'modelContentBlockStopEvent' }) + expect(events).toContainEqual({ stopReason: 'endTurn', type: 'modelMessageStopEvent' }) + }) + + describe('error handling', async () => { + it.each([ + { + name: 'ContextWindowOverflowError for context overflow', + error: new Error('Input is too long for requested model'), + expected: ContextWindowOverflowError, + }, + { + name: 'ValidationException for invalid input', + error: new ValidationException({ message: 'ValidationException', $metadata: {} }), + expected: ValidationException, + }, + ])('throws $name', async ({ error, expected }) => { + vi.clearAllMocks() + const mockSendError = vi.fn().mockRejectedValue(error) + mockBedrockClientImplementation({ send: mockSendError }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await expect(collectIterator(provider.stream(messages))).rejects.toThrow(expected) + }) + }) + }) + + describe('stream', () => { + it('handles tool use input delta', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { + contentBlockStart: { start: { toolUse: { name: 'calc', toolUseId: 'id' } } }, + } + yield { contentBlockDelta: { delta: { toolUse: { input: '{"a": 1}' } } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'tool_use' } } + yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const events = await collectIterator(provider.stream(messages)) + + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'toolUseInputDelta', + input: '{"a": 1}', + }, + }) + }) + + it('handles reasoning content delta with both text and signature, as well as redactedContent', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { + contentBlockDelta: { + delta: { reasoningContent: { text: 'thinking...', signature: 'sig123' } }, + }, + } + yield { + contentBlockDelta: { + delta: { reasoningContent: { redactedContent: new Uint8Array(1) } }, + }, + } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const events = await collectIterator(provider.stream(messages)) + + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'reasoningContentDelta', + text: 'thinking...', + signature: 'sig123', + }, + }) + expect(events).toContainEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'reasoningContentDelta', + redactedContent: new Uint8Array(1), + }, + }) + }) + + it('handles reasoning content delta with only text, skips unsupported types', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { + contentBlockDelta: { + delta: { reasoningContent: { text: 'thinking...' } }, + }, + } + yield { + contentBlockDelta: { + delta: { unknown: 'type' }, + }, + } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } + yield { unknown: 'type' } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const events = await collectIterator(provider.stream(messages)) + + const reasoningDelta = events.find( + (e) => e.type === 'modelContentBlockDeltaEvent' && e.delta.type === 'reasoningContentDelta' + ) + expect(reasoningDelta).toBeDefined() + if ( + reasoningDelta?.type === 'modelContentBlockDeltaEvent' && + reasoningDelta.delta.type === 'reasoningContentDelta' + ) { + expect(reasoningDelta.delta.text).toBe('thinking...') + expect(reasoningDelta.delta.signature).toBeUndefined() + } + }) + + it('handles reasoning content delta with only signature', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { + contentBlockDelta: { + delta: { reasoningContent: { signature: 'sig123' } }, + }, + } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const events = await collectIterator(provider.stream(messages)) + + const reasoningDelta = events.find( + (e) => e.type === 'modelContentBlockDeltaEvent' && e.delta.type === 'reasoningContentDelta' + ) + expect(reasoningDelta).toBeDefined() + if ( + reasoningDelta?.type === 'modelContentBlockDeltaEvent' && + reasoningDelta.delta.type === 'reasoningContentDelta' + ) { + expect(reasoningDelta.delta.text).toBeUndefined() + expect(reasoningDelta.delta.signature).toBe('sig123') + } + }) + + it('handles cache usage metrics', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { + metadata: { + usage: { + inputTokens: 100, + outputTokens: 50, + totalTokens: 150, + cacheReadInputTokens: 80, + cacheWriteInputTokens: 20, + }, + }, + } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const events = await collectIterator(provider.stream(messages)) + + const metadataEvent = events.find((e) => e.type === 'modelMetadataEvent') + expect(metadataEvent).toBeDefined() + if (metadataEvent?.type === 'modelMetadataEvent') { + expect(metadataEvent.usage?.cacheReadInputTokens).toBe(80) + expect(metadataEvent.usage?.cacheWriteInputTokens).toBe(20) + } + }) + + it('handles trace in metadata', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { guardrail: { action: 'INTERVENED' } }, + }, + } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const events = await collectIterator(provider.stream(messages)) + + const metadataEvent = events.find((e) => e.type === 'modelMetadataEvent') + expect(metadataEvent).toBeDefined() + if (metadataEvent?.type === 'modelMetadataEvent') { + expect(metadataEvent.trace).toBeDefined() + } + }) + + it('handles additionalModelResponseFields', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn', additionalModelResponseFields: { customField: 'value' } } } + yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const events = await collectIterator(provider.stream(messages)) + + const stopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(stopEvent).toBeDefined() + if (stopEvent?.type === 'modelMessageStopEvent') { + expect(stopEvent.additionalModelResponseFields).toStrictEqual({ customField: 'value' }) + } + }) + + describe('handles all stop reason types', () => { + const stopReasons = [ + ['end_turn', 'endTurn'], + ['tool_use', 'toolUse'], + ['max_tokens', 'maxTokens'], + ['stop_sequence', 'stopSequence'], + ['content_filtered', 'contentFiltered'], + ['guardrail_intervened', 'guardrailIntervened'], + ['model_context_window_exceeded', 'modelContextWindowExceeded'], + ['new_stop_reason', 'newStopReason'], + ] + for (const [bedrockReason, expectedReason] of stopReasons) { + it(`handles ${bedrockReason} stop reason types`, async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { messageStop: { stopReason: bedrockReason } } + yield { metadata: { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const events = [] + for await (const event of provider.stream(messages)) { + events.push(event) + } + + const stopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(stopEvent).toBeDefined() + if (stopEvent?.type === 'modelMessageStopEvent') { + expect(stopEvent.stopReason).toBe(expectedReason) + } + }) + } + }) + + describe('throttling', () => { + it('throws ModelThrottledError when throttlingException is received', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { throttlingException: { message: 'Rate exceeded' } } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // consume stream + } + }).rejects.toThrow(ModelThrottledError) + }) + + it('includes throttling message in ModelThrottledError', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { throttlingException: { message: 'Too many requests' } } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // consume stream + } + }).rejects.toThrow('Too many requests') + }) + + it('uses default message when throttlingException has no message', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { throttlingException: {} } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // consume stream + } + }).rejects.toThrow('Request was throttled by the model provider') + }) + }) + }) + + describe('system prompt formatting', async () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + beforeEach(() => { + vi.clearAllMocks() + }) + + it('does not add cache points to string system prompt with cacheConfig', async () => { + const provider = new BedrockModel({ cacheConfig: { strategy: 'auto' } }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: 'You are a helpful assistant', + } + + collectIterator(provider.stream(messages, options)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }, { cachePoint: { type: 'default' } }], + }, + ], + system: [{ text: 'You are a helpful assistant' }], + }) + }) + + it('formats array system prompt with text blocks only', async () => { + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: [ + { type: 'textBlock', text: 'You are a helpful assistant' }, + { type: 'textBlock', text: 'Additional context here' }, + ] as SystemContentBlock[], + } + + collectIterator(provider.stream(messages, options)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }], + }, + ], + system: [{ text: 'You are a helpful assistant' }, { text: 'Additional context here' }], + }) + }) + + it('formats array system prompt with cache points', async () => { + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: [ + { type: 'textBlock', text: 'You are a helpful assistant' }, + { type: 'textBlock', text: 'Large context document' }, + { type: 'cachePointBlock', cacheType: 'default' }, + ] as SystemContentBlock[], + } + + collectIterator(provider.stream(messages, options)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }], + }, + ], + system: [ + { text: 'You are a helpful assistant' }, + { text: 'Large context document' }, + { cachePoint: { type: 'default' } }, + ], + }) + }) + + it('does not warn when array system prompt is provided without cacheConfig', async () => { + const provider = new BedrockModel() + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: [ + { type: 'textBlock', text: 'You are a helpful assistant' }, + { type: 'cachePointBlock', cacheType: 'default' }, + ] as SystemContentBlock[], + } + + collectIterator(provider.stream(messages, options)) + + // Verify no warning was logged + expect(warnSpy).not.toHaveBeenCalled() + + // Verify array is used as-is + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }], + }, + ], + system: [{ text: 'You are a helpful assistant' }, { cachePoint: { type: 'default' } }], + }) + + warnSpy.mockRestore() + }) + + it('adds cache point after tools when cacheConfig enabled', async () => { + const provider = new BedrockModel({ cacheConfig: { strategy: 'auto' } }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + toolSpecs: [ + { + name: 'calculator', + description: 'Calculate', + inputSchema: { type: 'object' }, + }, + ], + } + + collectIterator(provider.stream(messages, options)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }, { cachePoint: { type: 'default' } }], + }, + ], + toolConfig: { + tools: [ + { + toolSpec: { + name: 'calculator', + description: 'Calculate', + inputSchema: { json: { type: 'object' } }, + }, + }, + { cachePoint: { type: 'default' } }, + ], + }, + }) + }) + + it('adds cache points to tools and messages when cacheConfig enabled', async () => { + const provider = new BedrockModel({ cacheConfig: { strategy: 'auto' } }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Hello')] }), + new Message({ role: 'assistant', content: [new TextBlock('Hi')] }), + ] + const options: StreamOptions = { + systemPrompt: 'You are a helpful assistant', + toolSpecs: [ + { + name: 'calculator', + description: 'Calculate', + inputSchema: { type: 'object' }, + }, + ], + } + + collectIterator(provider.stream(messages, options)) + + const call = mockConverseStreamCommand.mock.lastCall?.[0] + expect(call?.system).toStrictEqual([{ text: 'You are a helpful assistant' }]) + expect(call?.toolConfig?.tools).toStrictEqual([ + { + toolSpec: { + name: 'calculator', + description: 'Calculate', + inputSchema: { json: { type: 'object' } }, + }, + }, + { cachePoint: { type: 'default' } }, + ]) + const userMsg = call?.messages?.[0] + const lastBlock = userMsg?.content?.[userMsg.content.length - 1] + expect(lastBlock).toStrictEqual({ cachePoint: { type: 'default' } }) + const assistantMsg = call?.messages?.[1] + const assistantLastBlock = assistantMsg?.content?.[assistantMsg.content.length - 1] + expect(assistantLastBlock).not.toStrictEqual({ cachePoint: { type: 'default' } }) + }) + + it('propagates cacheConfig ttls independently to tools and last user message', async () => { + const provider = new BedrockModel({ + cacheConfig: { strategy: 'auto', toolsTTL: '1h', messagesTTL: '5m' }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + toolSpecs: [ + { + name: 'calculator', + description: 'Calculate', + inputSchema: { type: 'object' }, + }, + ], + } + + collectIterator(provider.stream(messages, options)) + + const call = mockConverseStreamCommand.mock.lastCall?.[0] + expect(call?.toolConfig?.tools).toStrictEqual([ + { + toolSpec: { + name: 'calculator', + description: 'Calculate', + inputSchema: { json: { type: 'object' } }, + }, + }, + { cachePoint: { type: 'default', ttl: '1h' } }, + ]) + const userMsg = call?.messages?.[0] + const lastBlock = userMsg?.content?.[userMsg.content.length - 1] + expect(lastBlock).toStrictEqual({ cachePoint: { type: 'default', ttl: '5m' } }) + }) + + it('propagates only toolsTTL when messagesTTL is not set', async () => { + const provider = new BedrockModel({ cacheConfig: { strategy: 'auto', toolsTTL: '1h' } }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + toolSpecs: [ + { + name: 'calculator', + description: 'Calculate', + inputSchema: { type: 'object' }, + }, + ], + } + + collectIterator(provider.stream(messages, options)) + + const call = mockConverseStreamCommand.mock.lastCall?.[0] + const toolsLast = call?.toolConfig?.tools?.[call.toolConfig.tools.length - 1] + expect(toolsLast).toStrictEqual({ cachePoint: { type: 'default', ttl: '1h' } }) + const userMsg = call?.messages?.[0] + const lastBlock = userMsg?.content?.[userMsg.content.length - 1] + expect(lastBlock).toStrictEqual({ cachePoint: { type: 'default' } }) + }) + + it('propagates only messagesTTL when toolsTTL is not set', async () => { + const provider = new BedrockModel({ cacheConfig: { strategy: 'auto', messagesTTL: '1h' } }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + toolSpecs: [ + { + name: 'calculator', + description: 'Calculate', + inputSchema: { type: 'object' }, + }, + ], + } + + collectIterator(provider.stream(messages, options)) + + const call = mockConverseStreamCommand.mock.lastCall?.[0] + const toolsLast = call?.toolConfig?.tools?.[call.toolConfig.tools.length - 1] + expect(toolsLast).toStrictEqual({ cachePoint: { type: 'default' } }) + const userMsg = call?.messages?.[0] + const lastBlock = userMsg?.content?.[userMsg.content.length - 1] + expect(lastBlock).toStrictEqual({ cachePoint: { type: 'default', ttl: '1h' } }) + }) + + it('omits ttl on auto-injected cache points when no ttl is set', async () => { + const provider = new BedrockModel({ cacheConfig: { strategy: 'auto' } }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + toolSpecs: [ + { + name: 'calculator', + description: 'Calculate', + inputSchema: { type: 'object' }, + }, + ], + } + + collectIterator(provider.stream(messages, options)) + + const call = mockConverseStreamCommand.mock.lastCall?.[0] + const toolsLast = call?.toolConfig?.tools?.[call.toolConfig.tools.length - 1] + expect(toolsLast).toStrictEqual({ cachePoint: { type: 'default' } }) + const userMsg = call?.messages?.[0] + const lastBlock = userMsg?.content?.[userMsg.content.length - 1] + expect(lastBlock).toStrictEqual({ cachePoint: { type: 'default' } }) + }) + + it('does not mutate the original messages array', async () => { + const provider = new BedrockModel({ cacheConfig: { strategy: 'auto' } }) + const originalMessages = [ + new Message({ role: 'user', content: [new TextBlock('Hello')] }), + new Message({ role: 'assistant', content: [new TextBlock('Hi')] }), + ] + + // Create a deep copy to compare against + const messagesCopy = JSON.parse(JSON.stringify(originalMessages)) + + collectIterator(provider.stream(originalMessages)) + + // Verify original messages are unchanged + expect(JSON.stringify(originalMessages)).toBe(JSON.stringify(messagesCopy)) + }) + + it('logs warning and disables caching for non-caching models', async () => { + const warnSpy = vi.spyOn(console, 'warn') + const provider = new BedrockModel({ + modelId: 'amazon.titan-text-express-v1', + cacheConfig: { strategy: 'auto' }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: 'You are a helpful assistant', + } + + collectIterator(provider.stream(messages, options)) + + // Verify warning was logged + expect(warnSpy).toHaveBeenCalled() + + // Verify no cache points were added + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'amazon.titan-text-express-v1', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }], + }, + ], + system: [{ text: 'You are a helpful assistant' }], + }) + + warnSpy.mockRestore() + }) + + it('enables caching with anthropic strategy for application inference profiles', async () => { + const provider = new BedrockModel({ + modelId: 'arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123', + cacheConfig: { strategy: 'anthropic' }, + }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Hello')] }), + new Message({ role: 'assistant', content: [new TextBlock('Hi')] }), + ] + + collectIterator(provider.stream(messages)) + + const call = mockConverseStreamCommand.mock.lastCall?.[0] + // Cache point should be on the user message (index 0) + const userMsg = call?.messages?.[0] + const lastBlock = userMsg?.content?.[userMsg.content.length - 1] + expect(lastBlock).toStrictEqual({ cachePoint: { type: 'default' } }) + }) + + it('handles empty array system prompt', async () => { + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: [], + } + + collectIterator(provider.stream(messages, options)) + + // Empty array should not set system field + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }], + }, + ], + }) + }) + + it('formats array system prompt with guard content', async () => { + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: [ + new TextBlock('You are a helpful assistant'), + new GuardContentBlock({ + text: { + qualifiers: ['grounding_source'], + text: 'This content should be evaluated for grounding.', + }, + }), + ], + } + + collectIterator(provider.stream(messages, options)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }], + }, + ], + system: [ + { text: 'You are a helpful assistant' }, + { + guardContent: { + text: { + text: 'This content should be evaluated for grounding.', + qualifiers: ['grounding_source'], + }, + }, + }, + ], + }) + }) + + it('formats mixed system prompt with text, guard content, and cache points', async () => { + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: [ + new TextBlock('You are a helpful assistant'), + new GuardContentBlock({ + text: { + qualifiers: ['grounding_source', 'query'], + text: 'Guard content', + }, + }), + new TextBlock('Additional context'), + new CachePointBlock({ cacheType: 'default' }), + ], + } + + collectIterator(provider.stream(messages, options)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }], + }, + ], + system: [ + { text: 'You are a helpful assistant' }, + { + guardContent: { + text: { + text: 'Guard content', + qualifiers: ['grounding_source', 'query'], + }, + }, + }, + { text: 'Additional context' }, + { cachePoint: { type: 'default' } }, + ], + }) + }) + + it('formats guard content with all qualifier types', async () => { + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const options: StreamOptions = { + systemPrompt: [ + new GuardContentBlock({ + text: { + qualifiers: ['grounding_source', 'query', 'guard_content'], + text: 'Multi-qualifier guard content', + }, + }), + ], + } + + collectIterator(provider.stream(messages, options)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }], + }, + ], + system: [ + { + guardContent: { + text: { + text: 'Multi-qualifier guard content', + qualifiers: ['grounding_source', 'query', 'guard_content'], + }, + }, + }, + ], + }) + }) + + it('formats guard content with image in system prompt', async () => { + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const imageBytes = new Uint8Array([1, 2, 3, 4]) + const options: StreamOptions = { + systemPrompt: [ + new GuardContentBlock({ + image: { + format: 'jpeg', + source: { bytes: imageBytes }, + }, + }), + ], + } + + collectIterator(provider.stream(messages, options)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }], + }, + ], + system: [ + { + guardContent: { + image: { + format: 'jpeg', + source: { bytes: imageBytes }, + }, + }, + }, + ], + }) + }) + }) + + describe('guard content in messages', async () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + beforeEach(() => { + vi.clearAllMocks() + }) + + it('formats guard content with text in message', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [ + new TextBlock('Verify this information:'), + new GuardContentBlock({ + text: { + qualifiers: ['grounding_source'], + text: 'The capital of France is Paris.', + }, + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [ + { text: 'Verify this information:' }, + { + guardContent: { + text: { + text: 'The capital of France is Paris.', + qualifiers: ['grounding_source'], + }, + }, + }, + ], + }, + ], + }) + }) + + it('formats guard content with image in message', async () => { + const provider = new BedrockModel() + const imageBytes = new Uint8Array([1, 2, 3, 4]) + const messages = [ + new Message({ + role: 'user', + content: [ + new TextBlock('Is this image safe?'), + new GuardContentBlock({ + image: { + format: 'jpeg', + source: { bytes: imageBytes }, + }, + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + modelId: 'global.anthropic.claude-sonnet-4-6', + messages: [ + { + role: 'user', + content: [ + { text: 'Is this image safe?' }, + { + guardContent: { + image: { + format: 'jpeg', + source: { bytes: imageBytes }, + }, + }, + }, + ], + }, + ], + }) + }) + }) + + describe('media blocks in tool results', () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + it('formats image block in tool result', async () => { + const provider = new BedrockModel() + const imageBytes = new Uint8Array([1, 2, 3]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new ImageBlock({ format: 'png', source: { bytes: imageBytes } })], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-1', + content: [{ image: { format: 'png', source: { bytes: imageBytes } } }], + status: 'success', + }, + }, + ], + }, + ], + }) + ) + }) + + it('formats video block in tool result with 3gp format mapping', async () => { + const provider = new BedrockModel() + const videoBytes = new Uint8Array([4, 5, 6]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new VideoBlock({ format: '3gp', source: { bytes: videoBytes } })], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-1', + content: [{ video: { format: 'three_gp', source: { bytes: videoBytes } } }], + status: 'success', + }, + }, + ], + }, + ], + }) + ) + }) + + it('formats document block in tool result', async () => { + const provider = new BedrockModel() + const docBytes = new Uint8Array([7, 8, 9]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new DocumentBlock({ name: 'report.pdf', format: 'pdf', source: { bytes: docBytes } })], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-1', + content: [{ document: { name: 'report.pdf', format: 'pdf', source: { bytes: docBytes } } }], + status: 'success', + }, + }, + ], + }, + ], + }) + ) + }) + + it('formats mixed text and media content in tool result', async () => { + const provider = new BedrockModel() + const imageBytes = new Uint8Array([1, 2]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new TextBlock('Here is the image:'), + new ImageBlock({ format: 'jpeg', source: { bytes: imageBytes } }), + ], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-1', + content: [ + { text: 'Here is the image:' }, + { image: { format: 'jpeg', source: { bytes: imageBytes } } }, + ], + status: 'success', + }, + }, + ], + }, + ], + }) + ) + }) + }) + + describe('media blocks in messages', () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + it('formats top-level image block', async () => { + const provider = new BedrockModel() + const imageBytes = new Uint8Array([1, 2, 3]) + const messages = [ + new Message({ + role: 'user', + content: [new ImageBlock({ format: 'png', source: { bytes: imageBytes } })], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [{ image: { format: 'png', source: { bytes: imageBytes } } }], + }, + ], + }) + ) + }) + + it('formats top-level image block with S3 source', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ format: 'png', source: { location: { type: 's3', uri: 's3://bucket/image.png' } } }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [{ image: { format: 'png', source: { s3Location: { uri: 's3://bucket/image.png' } } } }], + }, + ], + }) + ) + }) + + it('formats top-level video block with 3gp format mapping', async () => { + const provider = new BedrockModel() + const videoBytes = new Uint8Array([4, 5, 6]) + const messages = [ + new Message({ + role: 'user', + content: [new VideoBlock({ format: '3gp', source: { bytes: videoBytes } })], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [{ video: { format: 'three_gp', source: { bytes: videoBytes } } }], + }, + ], + }) + ) + }) + + it('formats top-level document block with text source converted to bytes', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [new DocumentBlock({ name: 'notes.txt', format: 'txt', source: { text: 'Hello world' } })], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + document: { + name: 'notes.txt', + format: 'txt', + source: { bytes: new TextEncoder().encode('Hello world') }, + }, + }, + ], + }, + ], + }) + ) + }) + }) + + describe('citations content block formatting', () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + it('maps SDK CitationLocation types to Bedrock object-key format through formatting pipeline', async () => { + const provider = new BedrockModel() + const sdkCitations = [ + { + location: { type: 'documentChar' as const, documentIndex: 0, start: 150, end: 300 }, + source: 'doc-0', + sourceContent: [{ text: 'char source' }], + title: 'Text Document', + }, + { + location: { type: 'documentPage' as const, documentIndex: 0, start: 2, end: 3 }, + source: 'doc-0', + sourceContent: [{ text: 'page source' }], + title: 'PDF Document', + }, + { + location: { type: 'documentChunk' as const, documentIndex: 1, start: 5, end: 8 }, + source: 'doc-1', + sourceContent: [{ text: 'chunk source' }], + title: 'Chunked Document', + }, + { + location: { type: 'searchResult' as const, searchResultIndex: 0, start: 25, end: 150 }, + source: 'search-0', + sourceContent: [{ text: 'search source' }], + title: 'Search Result', + }, + { + location: { type: 'web' as const, url: 'https://example.com/doc', domain: 'example.com' }, + source: 'web-0', + sourceContent: [{ text: 'web source' }], + title: 'Web Page', + }, + ] + + const messages = [ + new Message({ + role: 'assistant', + content: [ + new CitationsBlock({ + citations: sdkCitations, + content: [{ text: 'generated text with all citation types' }], + }), + ], + }), + new Message({ + role: 'user', + content: [new TextBlock('Follow up')], + }), + ] + + collectIterator(provider.stream(messages)) + + // Bedrock wire format uses object-key discrimination + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'assistant', + content: [ + { + citationsContent: { + citations: [ + { + location: { documentChar: { documentIndex: 0, start: 150, end: 300 } }, + source: 'doc-0', + sourceContent: [{ text: 'char source' }], + title: 'Text Document', + }, + { + location: { documentPage: { documentIndex: 0, start: 2, end: 3 } }, + source: 'doc-0', + sourceContent: [{ text: 'page source' }], + title: 'PDF Document', + }, + { + location: { documentChunk: { documentIndex: 1, start: 5, end: 8 } }, + source: 'doc-1', + sourceContent: [{ text: 'chunk source' }], + title: 'Chunked Document', + }, + { + location: { + searchResultLocation: { searchResultIndex: 0, start: 25, end: 150 }, + }, + source: 'search-0', + sourceContent: [{ text: 'search source' }], + title: 'Search Result', + }, + { + location: { web: { url: 'https://example.com/doc', domain: 'example.com' } }, + source: 'web-0', + sourceContent: [{ text: 'web source' }], + title: 'Web Page', + }, + ], + content: [{ text: 'generated text with all citation types' }], + }, + }, + ], + }, + { + role: 'user', + content: [{ text: 'Follow up' }], + }, + ], + }) + ) + }) + }) + + describe('media blocks in tool results', () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + it('formats image block in tool result', async () => { + const provider = new BedrockModel() + const imageBytes = new Uint8Array([1, 2, 3]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new ImageBlock({ format: 'png', source: { bytes: imageBytes } })], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-1', + content: [{ image: { format: 'png', source: { bytes: imageBytes } } }], + status: 'success', + }, + }, + ], + }, + ], + }) + ) + }) + + it('formats video block in tool result with 3gp format mapping', async () => { + const provider = new BedrockModel() + const videoBytes = new Uint8Array([4, 5, 6]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new VideoBlock({ format: '3gp', source: { bytes: videoBytes } })], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-1', + content: [{ video: { format: 'three_gp', source: { bytes: videoBytes } } }], + status: 'success', + }, + }, + ], + }, + ], + }) + ) + }) + + it('formats document block in tool result', async () => { + const provider = new BedrockModel() + const docBytes = new Uint8Array([7, 8, 9]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new DocumentBlock({ name: 'report.pdf', format: 'pdf', source: { bytes: docBytes } })], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-1', + content: [{ document: { name: 'report.pdf', format: 'pdf', source: { bytes: docBytes } } }], + status: 'success', + }, + }, + ], + }, + ], + }) + ) + }) + + it('formats mixed text and media content in tool result', async () => { + const provider = new BedrockModel() + const imageBytes = new Uint8Array([1, 2]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new TextBlock('Here is the image:'), + new ImageBlock({ format: 'jpeg', source: { bytes: imageBytes } }), + ], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-1', + content: [ + { text: 'Here is the image:' }, + { image: { format: 'jpeg', source: { bytes: imageBytes } } }, + ], + status: 'success', + }, + }, + ], + }, + ], + }) + ) + }) + }) + + describe('media blocks in messages', () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + it('formats top-level image block', async () => { + const provider = new BedrockModel() + const imageBytes = new Uint8Array([1, 2, 3]) + const messages = [ + new Message({ + role: 'user', + content: [new ImageBlock({ format: 'png', source: { bytes: imageBytes } })], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [{ image: { format: 'png', source: { bytes: imageBytes } } }], + }, + ], + }) + ) + }) + + it('formats top-level image block with S3 source', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ format: 'png', source: { location: { type: 's3', uri: 's3://bucket/image.png' } } }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [{ image: { format: 'png', source: { s3Location: { uri: 's3://bucket/image.png' } } } }], + }, + ], + }) + ) + }) + + it('formats top-level video block with 3gp format mapping', async () => { + const provider = new BedrockModel() + const videoBytes = new Uint8Array([4, 5, 6]) + const messages = [ + new Message({ + role: 'user', + content: [new VideoBlock({ format: '3gp', source: { bytes: videoBytes } })], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [{ video: { format: 'three_gp', source: { bytes: videoBytes } } }], + }, + ], + }) + ) + }) + + it('formats top-level document block with text source converted to bytes', async () => { + const provider = new BedrockModel() + const messages = [ + new Message({ + role: 'user', + content: [new DocumentBlock({ name: 'notes.txt', format: 'txt', source: { text: 'Hello world' } })], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + document: { + name: 'notes.txt', + format: 'txt', + source: { bytes: new TextEncoder().encode('Hello world') }, + }, + }, + ], + }, + ], + }) + ) + }) + }) + + describe('includeToolResultStatus configuration', async () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + describe('when includeToolResultStatus is true', () => { + it('always includes status field in tool results', async () => { + const provider = new BedrockModel({ includeToolResultStatus: true }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-123', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + messages: [ + { + content: [ + { + toolResult: { + content: [{ text: 'Result' }], + status: 'success', + toolUseId: 'tool-123', + }, + }, + ], + role: 'user', + }, + ], + toolConfig: BEDROCK_NOOP_TOOL_CONFIG, + modelId: expect.any(String), + }) + }) + }) + + describe('when includeToolResultStatus is false', () => { + it('never includes status field in tool results', async () => { + const provider = new BedrockModel({ includeToolResultStatus: false }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-123', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + messages: [ + { + content: [ + { + toolResult: { + content: [{ text: 'Result' }], + toolUseId: 'tool-123', + }, + }, + ], + role: 'user', + }, + ], + toolConfig: BEDROCK_NOOP_TOOL_CONFIG, + modelId: expect.any(String), + }) + }) + }) + + describe('when includeToolResultStatus is auto', () => { + it('includes status field for Claude models', async () => { + const provider = new BedrockModel({ + modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', + includeToolResultStatus: 'auto', + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-123', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + messages: [ + { + content: [ + { + toolResult: { + content: [{ text: 'Result' }], + status: 'success', + toolUseId: 'tool-123', + }, + }, + ], + role: 'user', + }, + ], + toolConfig: BEDROCK_NOOP_TOOL_CONFIG, + modelId: 'anthropic.claude-3-5-sonnet-20241022-v2:0', + }) + }) + }) + + describe('when includeToolResultStatus is undefined (default)', () => { + it('follows auto logic for non-Claude models', async () => { + const provider = new BedrockModel({ + modelId: 'amazon.nova-lite-v1:0', + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-123', + status: 'success', + content: [new TextBlock('Result')], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith({ + messages: [ + { + content: [ + { + toolResult: { + content: [{ text: 'Result' }], + toolUseId: 'tool-123', + }, + }, + ], + role: 'user', + }, + ], + toolConfig: BEDROCK_NOOP_TOOL_CONFIG, + modelId: 'amazon.nova-lite-v1:0', + }) + }) + }) + }) + + describe('region configuration', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('uses explicit region when provided', async () => { + mockBedrockClientImplementation() + + const provider = new BedrockModel({ region: 'eu-west-1' }) + + // After applyDefaultRegion wraps the config functions, verify they still return the correct value + const regionResult = await provider['_client'].config.region() + expect(regionResult).toBe('eu-west-1') + }) + + it('defaults to us-west-2 when region is missing', async () => { + mockBedrockClientImplementation({ + region: async () => { + throw new Error('Region is missing') + }, + useFipsEndpoint: async () => { + throw new Error('Region is missing') + }, + }) + + const provider = new BedrockModel() + + // After applyDefaultRegion wraps the config functions + const regionResult = await provider['_client'].config.region() + expect(regionResult).toBe('us-west-2') + + const fipsResult = await provider['_client'].config.useFipsEndpoint() + expect(fipsResult).toBe(false) + }) + + it('rethrows other region errors', async () => { + mockBedrockClientImplementation({ + region: async () => { + throw new Error('Network error') + }, + }) + + const provider = new BedrockModel() + + // Should rethrow the error + await expect(provider['_client'].config.region()).rejects.toThrow('Network error') + }) + }) + + describe('guardrail configuration', () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('constructor', () => { + it('accepts guardrailConfig in options', () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + }, + }) + expect(provider.getConfig().guardrailConfig).toStrictEqual({ + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + }) + }) + + it('accepts guardrailConfig with all options', () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + trace: 'enabled_full', + streamProcessingMode: 'sync', + redaction: { + input: true, + inputMessage: '[Custom input redacted.]', + output: true, + outputMessage: '[Custom output redacted.]', + }, + }, + }) + expect(provider.getConfig().guardrailConfig).toStrictEqual({ + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + trace: 'enabled_full', + streamProcessingMode: 'sync', + redaction: { + input: true, + inputMessage: '[Custom input redacted.]', + output: true, + outputMessage: '[Custom output redacted.]', + }, + }) + }) + }) + + describe('request formatting', () => { + it('includes guardrailConfig in request with default trace', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + trace: 'enabled', + }, + }) + ) + }) + + it('includes guardrailConfig in request with custom trace', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + trace: 'disabled', + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + trace: 'disabled', + }, + }) + ) + }) + + it('includes streamProcessingMode when specified', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + streamProcessingMode: 'sync', + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + trace: 'enabled', + streamProcessingMode: 'sync', + }, + }) + ) + }) + + it('does not include guardrailConfig when not configured', async () => { + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.not.objectContaining({ + guardrailConfig: expect.anything(), + }) + ) + }) + }) + + describe('blocked guardrail detection', () => { + it('detects blocked guardrail in inputAssessment', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { + '1234': { + topicPolicy: { + topics: [{ name: 'Harmful', action: 'BLOCKED', detected: true }], + }, + }, + }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const events = await collectIterator(provider.stream(messages)) + + const redactEvent = events.find((e) => e.type === 'modelRedactionEvent') + expect(redactEvent).toBeDefined() + expect(redactEvent).toStrictEqual({ + type: 'modelRedactionEvent', + inputRedaction: { replaceContent: '[User input redacted.]' }, + }) + }) + + it('detects blocked guardrail in outputAssessments', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + outputAssessments: { + '1234': { + contentPolicy: { + filters: [{ type: 'VIOLENCE', action: 'BLOCKED', detected: true }], + }, + }, + }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const events = await collectIterator(provider.stream(messages)) + + const redactEvent = events.find((e) => e.type === 'modelRedactionEvent') + expect(redactEvent).toBeDefined() + }) + + it('does not emit redaction events when guardrail not blocked', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'end_turn' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { + '1234': { + topicPolicy: { + topics: [{ name: 'Safe', action: 'NONE', detected: false }], + }, + }, + }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const events = await collectIterator(provider.stream(messages)) + + const redactEvent = events.find((e) => e.type === 'modelRedactionEvent') + expect(redactEvent).toBeUndefined() + }) + + it('does not emit redaction events without guardrailConfig', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'Hello' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { + '1234': { + topicPolicy: { + topics: [{ name: 'Harmful', action: 'BLOCKED', detected: true }], + }, + }, + }, + }, + }, + }, + } + }) + + const provider = new BedrockModel() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const events = await collectIterator(provider.stream(messages)) + + const redactEvent = events.find((e) => e.type === 'modelRedactionEvent') + expect(redactEvent).toBeUndefined() + }) + }) + + describe('redaction event generation', () => { + it('emits input redaction with default message', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { '1': { topicPolicy: { topics: [{ action: 'BLOCKED', detected: true }] } } }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'id', + guardrailVersion: '1', + }, + }) + const events = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hello')] })]) + ) + + expect(events).toContainEqual({ + type: 'modelRedactionEvent', + inputRedaction: { replaceContent: '[User input redacted.]' }, + }) + }) + + it('emits input redaction with custom message', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { '1': { topicPolicy: { topics: [{ action: 'BLOCKED', detected: true }] } } }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'id', + guardrailVersion: '1', + redaction: { + inputMessage: '[Custom input message]', + }, + }, + }) + const events = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hello')] })]) + ) + + expect(events).toContainEqual({ + type: 'modelRedactionEvent', + inputRedaction: { replaceContent: '[Custom input message]' }, + }) + }) + + it('does not emit input redaction when redactInput is false', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { '1': { topicPolicy: { topics: [{ action: 'BLOCKED', detected: true }] } } }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'id', + guardrailVersion: '1', + redaction: { + input: false, + }, + }, + }) + const events = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hello')] })]) + ) + + const inputRedactEvent = events.find((e) => e.type === 'modelRedactionEvent' && 'inputRedaction' in e) + expect(inputRedactEvent).toBeUndefined() + }) + + it('emits output redaction when redactOutput is true', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { '1': { topicPolicy: { topics: [{ action: 'BLOCKED', detected: true }] } } }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'id', + guardrailVersion: '1', + redaction: { + output: true, + }, + }, + }) + const events = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hello')] })]) + ) + + expect(events).toContainEqual({ + type: 'modelRedactionEvent', + outputRedaction: { replaceContent: '[Assistant output redacted.]' }, + }) + }) + + it('emits output redaction with custom message', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { '1': { topicPolicy: { topics: [{ action: 'BLOCKED', detected: true }] } } }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'id', + guardrailVersion: '1', + redaction: { + output: true, + outputMessage: '[Custom output message]', + }, + }, + }) + const events = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hello')] })]) + ) + + expect(events).toContainEqual({ + type: 'modelRedactionEvent', + outputRedaction: { replaceContent: '[Custom output message]' }, + }) + }) + + it('emits both input and output redaction when both are enabled', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { '1': { topicPolicy: { topics: [{ action: 'BLOCKED', detected: true }] } } }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'id', + guardrailVersion: '1', + redaction: { + input: true, + output: true, + }, + }, + }) + const events = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hello')] })]) + ) + + expect(events).toContainEqual({ + type: 'modelRedactionEvent', + inputRedaction: { replaceContent: '[User input redacted.]' }, + }) + expect(events).toContainEqual({ + type: 'modelRedactionEvent', + outputRedaction: { replaceContent: '[Assistant output redacted.]' }, + }) + }) + + it('includes redactedContent from modelOutput when available', async () => { + setupMockSend(async function* () { + yield { messageStart: { role: 'assistant' } } + yield { contentBlockStart: {} } + yield { contentBlockDelta: { delta: { text: 'This content was blocked' } } } + yield { contentBlockStop: {} } + yield { messageStop: { stopReason: 'guardrail_intervened' } } + yield { + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + modelOutput: ['This content ', 'was blocked'], + outputAssessments: { + '0': [{ topicPolicy: { topics: [{ action: 'BLOCKED', detected: true }] } }], + }, + }, + }, + }, + } + }) + + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'id', + guardrailVersion: '1', + redaction: { + output: true, + outputMessage: '[Blocked]', + }, + }, + }) + const events = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hello')] })]) + ) + + expect(events).toContainEqual({ + type: 'modelRedactionEvent', + outputRedaction: { + replaceContent: '[Blocked]', + redactedContent: 'This content was blocked', + }, + }) + }) + }) + + describe('non-streaming mode', () => { + it('emits redaction events in non-streaming mode when guardrail blocks', async () => { + const mockSend = vi.fn(async () => ({ + output: { + message: { + role: 'assistant', + content: [{ text: 'Hello' }], + }, + }, + stopReason: 'guardrail_intervened', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + trace: { + guardrail: { + inputAssessment: { '1': { topicPolicy: { topics: [{ action: 'BLOCKED', detected: true }] } } }, + }, + }, + })) + mockBedrockClientImplementation({ send: mockSend }) + + const provider = new BedrockModel({ + stream: false, + guardrailConfig: { + guardrailIdentifier: 'id', + guardrailVersion: '1', + }, + }) + const events = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hello')] })]) + ) + + expect(events).toContainEqual({ + type: 'modelRedactionEvent', + inputRedaction: { replaceContent: '[User input redacted.]' }, + }) + }) + }) + + describe('guardLatestUserMessage', () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + beforeEach(() => { + vi.clearAllMocks() + }) + + it('accepts guardLatestUserMessage in guardrailConfig', () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + expect(provider.getConfig().guardrailConfig).toStrictEqual({ + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }) + }) + + it('wraps latest user message text content in guardContent when enabled', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello world')] })] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + guardContent: { + text: { + text: 'Hello world', + }, + }, + }, + ], + }, + ], + }) + ) + }) + + it('wraps latest user message image content in guardContent when enabled', async () => { + const imageBytes = new Uint8Array([1, 2, 3, 4]) + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ + format: 'jpeg', + source: { bytes: imageBytes }, + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + guardContent: { + image: { + format: 'jpeg', + source: { bytes: imageBytes }, + }, + }, + }, + ], + }, + ], + }) + ) + }) + + it('does not wrap toolResult messages even though role is user', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('What is 2+2?')] }), + new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ + name: 'calculator', + toolUseId: 'tool-123', + input: { expression: '2+2' }, + }), + ], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-123', + status: 'success', + content: [new TextBlock('4')], + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + // The latest message is a toolResult, but guardContent should wrap the FIRST user message + // which contains text, not the toolResult + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + guardContent: { + text: { + text: 'What is 2+2?', + }, + }, + }, + ], + }, + { + role: 'assistant', + content: [ + { + toolUse: { + name: 'calculator', + toolUseId: 'tool-123', + input: { expression: '2+2' }, + }, + }, + ], + }, + { + role: 'user', + content: [ + { + toolResult: expect.objectContaining({ + toolUseId: 'tool-123', + }), + }, + ], + }, + ], + }) + ) + }) + + it('does not wrap messages when guardLatestUserMessage is false', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: false, + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello world')] })] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [{ text: 'Hello world' }], + }, + ], + }) + ) + }) + + it('does not wrap messages when guardLatestUserMessage is undefined', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello world')] })] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [{ text: 'Hello world' }], + }, + ], + }) + ) + }) + + it('does not wrap assistant messages', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Hello')] }), + new Message({ role: 'assistant', content: [new TextBlock('Hi there!')] }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + guardContent: { + text: { + text: 'Hello', + }, + }, + }, + ], + }, + { + role: 'assistant', + content: [{ text: 'Hi there!' }], + }, + ], + }) + ) + }) + + it('wraps only the last user text/image message in multi-turn conversation', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('First message')] }), + new Message({ role: 'assistant', content: [new TextBlock('First response')] }), + new Message({ role: 'user', content: [new TextBlock('Second message')] }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [{ text: 'First message' }], + }, + { + role: 'assistant', + content: [{ text: 'First response' }], + }, + { + role: 'user', + content: [ + { + guardContent: { + text: { + text: 'Second message', + }, + }, + }, + ], + }, + ], + }) + ) + }) + + it('handles no user messages with text/image content gracefully', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + // Only assistant message, no user text/image content + const messages = [new Message({ role: 'assistant', content: [new TextBlock('Hello!')] })] + + collectIterator(provider.stream(messages)) + + // Should not throw and should not wrap anything + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'assistant', + content: [{ text: 'Hello!' }], + }, + ], + }) + ) + }) + + it('preserves explicit GuardContentBlock in messages without double-wrapping', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new GuardContentBlock({ + text: { + qualifiers: ['grounding_source'], + text: 'Already guarded content', + }, + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + // Explicit GuardContentBlock should be preserved as-is (no text/image content to wrap) + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + guardContent: { + text: { + text: 'Already guarded content', + qualifiers: ['grounding_source'], + }, + }, + }, + ], + }, + ], + }) + ) + }) + + it('wraps all text and image blocks in the latest user message', async () => { + const imageBytes = new Uint8Array([5, 6, 7, 8]) + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new TextBlock('Check this text'), + new ImageBlock({ + format: 'png', + source: { bytes: imageBytes }, + }), + new TextBlock('And this text too'), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + guardContent: { + text: { + text: 'Check this text', + }, + }, + }, + { + guardContent: { + image: { + format: 'png', + source: { bytes: imageBytes }, + }, + }, + }, + { + guardContent: { + text: { + text: 'And this text too', + }, + }, + }, + ], + }, + ], + }) + ) + }) + + it('skips wrapping images with unsupported formats (gif)', async () => { + const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const imageBytes = new Uint8Array([1, 2, 3, 4]) + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ + format: 'gif', + source: { bytes: imageBytes }, + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'image_format= | format not supported by bedrock guardrails | skipping guardContent wrap' + ) + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + image: { + format: 'gif', + source: { bytes: imageBytes }, + }, + }, + ], + }, + ], + }) + ) + consoleWarnSpy.mockRestore() + }) + + it('skips wrapping images with unsupported formats (webp)', async () => { + const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const imageBytes = new Uint8Array([1, 2, 3, 4]) + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ + format: 'webp', + source: { bytes: imageBytes }, + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'image_format= | format not supported by bedrock guardrails | skipping guardContent wrap' + ) + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + image: { + format: 'webp', + source: { bytes: imageBytes }, + }, + }, + ], + }, + ], + }) + ) + consoleWarnSpy.mockRestore() + }) + + it('skips wrapping images with S3 source', async () => { + const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ + format: 'png', + source: { + location: { + type: 's3', + uri: 's3://bucket/image.png', + }, + }, + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'source_type= | image source must be bytes for bedrock guardrails | skipping guardContent wrap' + ) + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + image: { + format: 'png', + source: { + s3Location: { + uri: 's3://bucket/image.png', + }, + }, + }, + }, + ], + }, + ], + }) + ) + consoleWarnSpy.mockRestore() + }) + + it('skips wrapping images with URL source', async () => { + const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ + format: 'jpeg', + source: { url: 'https://example.com/image.jpg' }, + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + // URL sources return undefined in _formatMediaSource, resulting in source: undefined + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'source_type= | not supported by bedrock | skipping' + ) + // The image block still appears but with undefined source (Bedrock will reject this) + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + image: { + format: 'jpeg', + source: undefined, + }, + }, + ], + }, + ], + }) + ) + consoleWarnSpy.mockRestore() + }) + + it('wraps supported image formats (png and jpeg) with bytes source', async () => { + const imageBytes = new Uint8Array([1, 2, 3, 4]) + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ + format: 'png', + source: { bytes: imageBytes }, + }), + new ImageBlock({ + format: 'jpeg', + source: { bytes: imageBytes }, + }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + guardContent: { + image: { + format: 'png', + source: { bytes: imageBytes }, + }, + }, + }, + { + guardContent: { + image: { + format: 'jpeg', + source: { bytes: imageBytes }, + }, + }, + }, + ], + }, + ], + }) + ) + }) + + it('does not wrap reasoning or cachePoint blocks', async () => { + const provider = new BedrockModel({ + guardrailConfig: { + guardrailIdentifier: 'my-guardrail-id', + guardrailVersion: '1', + guardLatestUserMessage: true, + }, + }) + const messages = [ + new Message({ + role: 'user', + content: [ + new TextBlock('User message'), + new ReasoningBlock({ text: 'thinking...', signature: 'sig' }), + new CachePointBlock({ cacheType: 'default' }), + ], + }), + ] + + collectIterator(provider.stream(messages)) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + messages: [ + { + role: 'user', + content: [ + { + guardContent: { + text: { + text: 'User message', + }, + }, + }, + { + reasoningContent: { + reasoningText: { + text: 'thinking...', + signature: 'sig', + }, + }, + }, + { cachePoint: { type: 'default' } }, + ], + }, + ], + }) + ) + }) + }) + }) + + describe('thinking with forced tool choice', () => { + const mockConverseStreamCommand = vi.mocked(ConverseStreamCommand) + + const provider = new BedrockModel({ + modelId: 'anthropic.claude-sonnet-4-20250514-v1:0', + additionalRequestFields: { + thinking: { type: 'enabled', budget_tokens: 5000 }, + some_other_field: 'value', + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + const toolSpecs = [{ name: 'test_tool', description: 'test' }] + + it.each([ + { name: 'any', toolChoice: { any: {} } }, + { name: 'tool', toolChoice: { tool: { name: 'test_tool' } } }, + ])('strips thinking from additional request fields when tool choice is $name', ({ toolChoice }) => { + collectIterator(provider.stream(messages, { toolSpecs, toolChoice })) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + additionalModelRequestFields: { some_other_field: 'value' }, + }) + ) + }) + + it('preserves thinking when tool choice is auto', () => { + collectIterator(provider.stream(messages, { toolSpecs, toolChoice: { auto: {} } })) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + additionalModelRequestFields: { + thinking: { type: 'enabled', budget_tokens: 5000 }, + some_other_field: 'value', + }, + }) + ) + }) + + it('preserves thinking when no tool choice is provided', () => { + collectIterator(provider.stream(messages, { toolSpecs })) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ + additionalModelRequestFields: { + thinking: { type: 'enabled', budget_tokens: 5000 }, + some_other_field: 'value', + }, + }) + ) + }) + + it('omits additionalModelRequestFields when thinking is the only field and tool choice forces tool use', () => { + const thinkingOnlyProvider = new BedrockModel({ + modelId: 'anthropic.claude-sonnet-4-20250514-v1:0', + additionalRequestFields: { + thinking: { type: 'enabled', budget_tokens: 5000 }, + }, + }) + + collectIterator(thinkingOnlyProvider.stream(messages, { toolSpecs, toolChoice: { any: {} } })) + + expect(mockConverseStreamCommand).toHaveBeenLastCalledWith( + expect.not.objectContaining({ + additionalModelRequestFields: expect.anything(), + }) + ) + }) + }) + + describe('countTokens', () => { + const messages: Message[] = [new Message({ role: 'user', content: [new TextBlock('hello')] })] + const toolSpecs = [ + { name: 'test_tool', description: 'A test tool', inputSchema: { type: 'object' as const, properties: {} } }, + ] + + beforeEach(() => { + vi.clearAllMocks() + BedrockModel.clearCountTokensCache() + }) + + it('should use heuristic by default when useNativeTokenCount is not set', async () => { + const mockSend = vi.fn() + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel() + + const result = await model.countTokens(messages) + + expect(mockSend).not.toHaveBeenCalled() + expect(result).toBe(2) // heuristic: Math.ceil('hello'.length / 4) + }) + + it('should return native token count on success', async () => { + const mockSend = vi.fn(async () => ({ inputTokens: 42 })) + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(result).toBe(42) + expect(mockSend).toHaveBeenCalledOnce() + }) + + it('should include system prompt in request', async () => { + const mockSend = vi.fn(async () => ({ inputTokens: 55 })) + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ useNativeTokenCount: true }) + + const result = await model.countTokens(messages, { systemPrompt: 'Be helpful.' }) + + expect(result).toBe(55) + const commandInput = vi.mocked(CountTokensCommand).mock.calls[0]![0]! + expect(commandInput).toStrictEqual({ + modelId: expect.any(String), + input: { + converse: { + messages: [{ role: 'user', content: [{ text: 'hello' }] }], + system: [{ text: 'Be helpful.' }], + }, + }, + }) + }) + + it('should include tool specs in request', async () => { + const mockSend = vi.fn(async () => ({ inputTokens: 100 })) + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ useNativeTokenCount: true }) + + const result = await model.countTokens(messages, { toolSpecs }) + + expect(result).toBe(100) + const commandInput = vi.mocked(CountTokensCommand).mock.calls[0]![0]! + expect(commandInput).toStrictEqual({ + modelId: expect.any(String), + input: { + converse: { + messages: [{ role: 'user', content: [{ text: 'hello' }] }], + toolConfig: { + tools: [ + { + toolSpec: { + name: 'test_tool', + description: 'A test tool', + inputSchema: { json: { type: 'object', properties: {} } }, + }, + }, + ], + }, + }, + }, + }) + }) + + it('should strip inferenceConfig from request', async () => { + const mockSend = vi.fn(async () => ({ inputTokens: 10 })) + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ maxTokens: 100, useNativeTokenCount: true }) + + await model.countTokens(messages) + + const commandInput = vi.mocked(CountTokensCommand).mock.calls[0]![0]! + expect(commandInput).toStrictEqual({ + modelId: expect.any(String), + input: { + converse: { + messages: [{ role: 'user', content: [{ text: 'hello' }] }], + }, + }, + }) + }) + + it('should fall back to estimation on API error', async () => { + const mockSend = vi.fn(async () => { + throw new Error('API error') + }) + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(typeof result).toBe('number') + expect(result).toBeGreaterThanOrEqual(0) + }) + + it('should fall back to estimation on generic exception', async () => { + const mockSend = vi.fn(async () => { + throw new Error('Connection failed') + }) + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(typeof result).toBe('number') + expect(result).toBeGreaterThanOrEqual(0) + }) + + it('should cache model ID and skip API call when model does not support counting tokens', async () => { + const unsupportedError = new Error("The provided model doesn't support counting tokens") + unsupportedError.name = 'ValidationException' + const mockSend = vi.fn(async () => { + throw unsupportedError + }) + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ useNativeTokenCount: true }) + + // First call: hits API, gets error, caches + await model.countTokens(messages) + expect(mockSend).toHaveBeenCalledOnce() + + // Second call: skips API entirely + await model.countTokens(messages) + expect(mockSend).toHaveBeenCalledOnce() + }) + + it('should cache model ID and skip API call on AccessDeniedException', async () => { + const accessDeniedError = new Error( + 'User: arn:aws:sts::123456789012:assumed-role/role is not authorized to perform: bedrock:CountTokens' + ) + accessDeniedError.name = 'AccessDeniedException' + const mockSend = vi.fn(async () => { + throw accessDeniedError + }) + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ useNativeTokenCount: true }) + + // First call: hits API, gets AccessDeniedException, caches + await model.countTokens(messages) + expect(mockSend).toHaveBeenCalledOnce() + + // Second call: skips API entirely due to caching + await model.countTokens(messages) + expect(mockSend).toHaveBeenCalledOnce() + }) + + it('should not cache model ID for other errors', async () => { + const mockSend = vi.fn(async () => { + throw new Error('Transient network error') + }) + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ useNativeTokenCount: true }) + + await model.countTokens(messages) + expect(mockSend).toHaveBeenCalledTimes(1) + + // Second call should still attempt the API + await model.countTokens(messages) + expect(mockSend).toHaveBeenCalledTimes(2) + }) + + it('should skip native API and use heuristic when useNativeTokenCount is false', async () => { + const mockSend = vi.fn() + mockBedrockClientImplementation({ send: mockSend }) + const model = new BedrockModel({ useNativeTokenCount: false }) + + const result = await model.countTokens(messages) + + expect(mockSend).not.toHaveBeenCalled() + expect(result).toBe(2) // heuristic: Math.ceil('hello'.length / 4) + }) + }) +}) diff --git a/strands-ts/src/models/__tests__/defaults.test.ts b/strands-ts/src/models/__tests__/defaults.test.ts new file mode 100644 index 0000000000..6236a47888 --- /dev/null +++ b/strands-ts/src/models/__tests__/defaults.test.ts @@ -0,0 +1,39 @@ +import { describe, it, expect } from 'vitest' +import { getContextWindowLimit } from '../defaults.js' + +describe('getContextWindowLimit', () => { + it('returns the context window limit for known model IDs across all providers', () => { + // Anthropic direct API + expect(getContextWindowLimit('claude-sonnet-4-6')).toBe(1_000_000) + expect(getContextWindowLimit('claude-opus-4-6')).toBe(1_000_000) + expect(getContextWindowLimit('claude-opus-4-5')).toBe(200_000) + expect(getContextWindowLimit('claude-haiku-4-5')).toBe(200_000) + // Bedrock Anthropic + expect(getContextWindowLimit('anthropic.claude-sonnet-4-6')).toBe(1_000_000) + // Bedrock Amazon Nova + expect(getContextWindowLimit('amazon.nova-pro-v1:0')).toBe(300_000) + expect(getContextWindowLimit('amazon.nova-micro-v1:0')).toBe(128_000) + // OpenAI + expect(getContextWindowLimit('gpt-5.4')).toBe(1_050_000) + expect(getContextWindowLimit('gpt-4o')).toBe(128_000) + expect(getContextWindowLimit('o3')).toBe(200_000) + expect(getContextWindowLimit('o4-mini')).toBe(200_000) + // Gemini + expect(getContextWindowLimit('gemini-2.5-flash')).toBe(1_048_576) + expect(getContextWindowLimit('gemini-2.5-pro')).toBe(1_048_576) + }) + + it('strips Bedrock cross-region prefix before lookup', () => { + expect(getContextWindowLimit('us.anthropic.claude-sonnet-4-6')).toBe(1_000_000) + expect(getContextWindowLimit('global.anthropic.claude-sonnet-4-6')).toBe(1_000_000) + }) + + it('does not strip unknown prefixes', () => { + expect(getContextWindowLimit('custom.gpt-5.4')).toBeUndefined() + }) + + it('returns undefined for unknown model IDs', () => { + expect(getContextWindowLimit('unknown-model-xyz')).toBeUndefined() + expect(getContextWindowLimit('us.unknown.model-v1:0')).toBeUndefined() + }) +}) diff --git a/strands-ts/src/models/__tests__/google.test.ts b/strands-ts/src/models/__tests__/google.test.ts new file mode 100644 index 0000000000..aacfc3c162 --- /dev/null +++ b/strands-ts/src/models/__tests__/google.test.ts @@ -0,0 +1,1351 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { GoogleGenAI, FunctionCallingConfigMode, type GenerateContentResponse } from '@google/genai' +import { collectIterator } from '../../__fixtures__/model-test-helpers.js' +import { GoogleModel } from '../google/model.js' +import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' +import { + Message, + CachePointBlock, + GuardContentBlock, + ReasoningBlock, + TextBlock, + ToolResultBlock, + ToolUseBlock, +} from '../../types/messages.js' +import type { ContentBlock } from '../../types/messages.js' +import { formatMessages, mapChunkToEvents } from '../google/adapters.js' +import type { GoogleStreamState } from '../google/types.js' +import { ImageBlock, DocumentBlock, VideoBlock } from '../../types/media.js' +import { warnOnce } from '../../logging/warn-once.js' + +vi.mock('../../logging/warn-once.js', () => ({ + warnOnce: vi.fn(), +})) + +/** + * Helper to create a mock Gemini client with streaming support + */ +function createMockClient(streamGenerator: () => AsyncGenerator>): GoogleGenAI { + return { + models: { + generateContentStream: vi.fn(async () => streamGenerator()), + }, + } as unknown as GoogleGenAI +} + +/** + * Helper to create a mock Gemini client that captures the request parameters. + * Returns the client and a captured object with `config` and `contents` fields + * populated after a stream call. + */ +function createMockClientWithCapture(): { client: GoogleGenAI; captured: Record } { + const captured: Record = {} + const client = { + models: { + generateContentStream: vi.fn(async (params: Record) => { + Object.assign(captured, params) + return (async function* () { + yield { candidates: [{ finishReason: 'STOP' }] } + })() + }), + }, + } as unknown as GoogleGenAI + return { client, captured } +} + +/** + * Helper to set up a capture-based test with provider, captured params, and a default user message. + */ +function setupCaptureTest(): { + provider: GoogleModel + captured: Record + messages: Message[] +} { + const { client, captured } = createMockClientWithCapture() + const provider = new GoogleModel({ client }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + return { provider, captured, messages } +} + +/** + * Helper to set up a stream-based test with a mock client, provider, and default user message. + */ +function setupStreamTest(streamGenerator: () => AsyncGenerator>): { + provider: GoogleModel + messages: Message[] +} { + const client = createMockClient(streamGenerator) + const provider = new GoogleModel({ client }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + return { provider, messages } +} + +/** + * Helper to format a single content block via formatMessages. + */ +function formatBlock(block: ContentBlock, role: 'user' | 'assistant' = 'user'): ReturnType { + return formatMessages([new Message({ role, content: [block] })]) +} + +describe('GoogleModel', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.stubEnv('GEMINI_API_KEY', 'test-api-key') + }) + + describe('constructor', () => { + it('creates instance with API key', () => { + const provider = new GoogleModel({ apiKey: 'test-key', modelId: 'gemini-2.0-flash' }) + expect(provider.getConfig().modelId).toBe('gemini-2.0-flash') + }) + + it('throws error when no API key provided and no env variable', () => { + vi.stubEnv('GEMINI_API_KEY', '') + + expect(() => new GoogleModel()).toThrow('Gemini API key is required') + }) + + it('does not require API key when client is provided', () => { + vi.stubEnv('GEMINI_API_KEY', '') + + const mockClient = createMockClient(async function* () { + yield { candidates: [{ finishReason: 'STOP' }] } + }) + + expect(() => new GoogleModel({ client: mockClient })).not.toThrow() + }) + + it('warns when modelId is not explicitly set', () => { + new GoogleModel({ apiKey: 'test-key' }) + expect(warnOnce).toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default modelId') + ) + }) + + it('does not warn when modelId is explicitly set', () => { + new GoogleModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash' }) + expect(warnOnce).not.toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default modelId') + ) + }) + }) + + describe('updateConfig', () => { + it('merges new config with existing config', () => { + const provider = new GoogleModel({ apiKey: 'test-key', modelId: 'gemini-2.5-flash' }) + provider.updateConfig({ params: { temperature: 0.5 } }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gemini-2.5-flash', + params: { temperature: 0.5 }, + contextWindowLimit: 1_048_576, + }) + }) + + it('re-resolves contextWindowLimit when modelId changes and it was auto-resolved', () => { + const provider = new GoogleModel({ apiKey: 'test-key' }) + expect(provider.getConfig().contextWindowLimit).toBe(1_048_576) + + provider.updateConfig({ modelId: 'gemini-2.0-flash' }) + expect(provider.getConfig().contextWindowLimit).toBe(1_048_576) + }) + + it('clears contextWindowLimit when modelId changes to unknown model', () => { + const provider = new GoogleModel({ apiKey: 'test-key' }) + expect(provider.getConfig().contextWindowLimit).toBe(1_048_576) + + provider.updateConfig({ modelId: 'my-custom-finetuned-model' }) + expect(provider.getConfig().contextWindowLimit).toBeUndefined() + }) + + it('preserves explicit contextWindowLimit when modelId changes', () => { + const provider = new GoogleModel({ apiKey: 'test-key', contextWindowLimit: 50_000 }) + expect(provider.getConfig().contextWindowLimit).toBe(50_000) + + provider.updateConfig({ modelId: 'gemini-2.0-flash' }) + expect(provider.getConfig().contextWindowLimit).toBe(50_000) + }) + }) + + describe('getConfig', () => { + it('returns the current configuration', () => { + const provider = new GoogleModel({ + apiKey: 'test-key', + modelId: 'gemini-2.5-flash', + params: { maxOutputTokens: 1024, temperature: 0.7 }, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gemini-2.5-flash', + params: { maxOutputTokens: 1024, temperature: 0.7 }, + contextWindowLimit: 1_048_576, + }) + }) + + it('includes contextWindowLimit in config when provided', () => { + const provider = new GoogleModel({ + apiKey: 'test-key', + modelId: 'gemini-2.5-flash', + contextWindowLimit: 1_048_576, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gemini-2.5-flash', + contextWindowLimit: 1_048_576, + }) + }) + + it('auto-populates contextWindowLimit from model ID lookup', () => { + const provider = new GoogleModel({ apiKey: 'test-key', modelId: 'gemini-2.5-pro' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gemini-2.5-pro', + contextWindowLimit: 1_048_576, + }) + }) + + it('auto-populates contextWindowLimit for default model ID', () => { + const provider = new GoogleModel({ apiKey: 'test-key' }) + expect(provider.getConfig()).toStrictEqual({ + contextWindowLimit: 1_048_576, + }) + }) + + it('does not override explicit contextWindowLimit', () => { + const provider = new GoogleModel({ + apiKey: 'test-key', + modelId: 'gemini-2.5-flash', + contextWindowLimit: 500_000, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gemini-2.5-flash', + contextWindowLimit: 500_000, + }) + }) + + it('leaves contextWindowLimit undefined for unknown model IDs', () => { + const provider = new GoogleModel({ apiKey: 'test-key', modelId: 'unknown-model' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'unknown-model', + }) + }) + }) + + describe('stream', () => { + it('throws error when messages array is empty', async () => { + const provider = new GoogleModel({ apiKey: 'test-key' }) + + await expect(collectIterator(provider.stream([]))).rejects.toThrow('At least one message is required') + }) + + it('emits message start and stop events', async () => { + const { provider, messages } = setupStreamTest(async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Hello' }] }, + }, + ], + } + yield { candidates: [{ finishReason: 'STOP' }] } + }) + + const events = await collectIterator(provider.stream(messages)) + + expect(events[0]).toEqual({ type: 'modelMessageStartEvent', role: 'assistant' }) + expect(events[events.length - 1]).toEqual({ type: 'modelMessageStopEvent', stopReason: 'endTurn' }) + }) + + it('emits text content block events', async () => { + const { provider, messages } = setupStreamTest(async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Hello' }] }, + }, + ], + } + yield { + candidates: [ + { + content: { parts: [{ text: ' world' }] }, + }, + ], + } + yield { candidates: [{ finishReason: 'STOP' }] } + }) + + const events = await collectIterator(provider.stream(messages)) + + expect(events).toHaveLength(6) + expect(events[0]).toEqual({ type: 'modelMessageStartEvent', role: 'assistant' }) + expect(events[1]).toEqual({ type: 'modelContentBlockStartEvent' }) + expect(events[2]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + }) + expect(events[3]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: ' world' }, + }) + expect(events[4]).toEqual({ type: 'modelContentBlockStopEvent' }) + expect(events[5]).toEqual({ type: 'modelMessageStopEvent', stopReason: 'endTurn' }) + }) + + it('emits usage metadata when available', async () => { + const { provider, messages } = setupStreamTest(async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Hi' }] }, + }, + ], + usageMetadata: { + promptTokenCount: 10, + totalTokenCount: 15, + }, + } + yield { candidates: [{ finishReason: 'STOP' }] } + }) + + const events = await collectIterator(provider.stream(messages)) + + const metadataEvent = events.find((e) => e.type === 'modelMetadataEvent') + expect(metadataEvent).toEqual({ + type: 'modelMetadataEvent', + usage: { + inputTokens: 10, + outputTokens: 5, + totalTokens: 15, + }, + }) + }) + + it('handles MAX_TOKENS finish reason', async () => { + const { provider, messages } = setupStreamTest(async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Truncated' }] }, + }, + ], + } + yield { candidates: [{ finishReason: 'MAX_TOKENS' }] } + }) + + const events = await collectIterator(provider.stream(messages)) + + const stopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(stopEvent).toBeDefined() + expect(stopEvent!.stopReason).toBe('maxTokens') + }) + }) + + describe('error handling', () => { + it('throws ContextWindowOverflowError for context overflow errors', async () => { + const mockClient = { + models: { + generateContentStream: vi.fn(async () => { + throw new Error( + JSON.stringify({ + error: { + status: 'INVALID_ARGUMENT', + message: 'Request exceeds the maximum number of tokens allowed', + }, + }) + ) + }), + }, + } as unknown as GoogleGenAI + + const provider = new GoogleModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ContextWindowOverflowError) + }) + + it('throws ModelThrottledError for RESOURCE_EXHAUSTED status', async () => { + const mockClient = { + models: { + generateContentStream: vi.fn(async () => { + throw new Error( + JSON.stringify({ + error: { + status: 'RESOURCE_EXHAUSTED', + message: 'Quota exceeded for the model', + }, + }) + ) + }), + }, + } as unknown as GoogleGenAI + + const provider = new GoogleModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError) + }) + + it('throws ModelThrottledError for UNAVAILABLE status', async () => { + const mockClient = { + models: { + generateContentStream: vi.fn(async () => { + throw new Error( + JSON.stringify({ + error: { + status: 'UNAVAILABLE', + message: 'Service temporarily unavailable', + }, + }) + ) + }), + }, + } as unknown as GoogleGenAI + + const provider = new GoogleModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(collectIterator(provider.stream(messages))).rejects.toThrow(ModelThrottledError) + }) + + it('rethrows unrecognized errors', async () => { + const mockClient = { + models: { + generateContentStream: vi.fn(async () => { + throw new Error('Network error') + }), + }, + } as unknown as GoogleGenAI + + const provider = new GoogleModel({ client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(collectIterator(provider.stream(messages))).rejects.toThrow('Network error') + }) + }) + + describe('system prompt', () => { + it('passes string system prompt to config', async () => { + const { provider, captured, messages } = setupCaptureTest() + + await collectIterator(provider.stream(messages, { systemPrompt: 'You are a helpful assistant' })) + + const config = captured.config as { systemInstruction?: string } + expect(config.systemInstruction).toBe('You are a helpful assistant') + }) + + it('ignores empty string system prompt', async () => { + const { provider, captured, messages } = setupCaptureTest() + + await collectIterator(provider.stream(messages, { systemPrompt: ' ' })) + + const config = captured.config as { systemInstruction?: string } + expect(config.systemInstruction).toBeUndefined() + }) + }) + + describe('message formatting', () => { + it('formats user messages correctly', async () => { + const { provider, captured } = setupCaptureTest() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await collectIterator(provider.stream(messages)) + + const contents = captured.contents as Array<{ role: string; parts: Array<{ text: string }> }> + expect(contents).toHaveLength(1) + expect(contents[0]?.role).toBe('user') + expect(contents[0]?.parts[0]?.text).toBe('Hello') + }) + + it('formats assistant messages correctly', async () => { + const { provider, captured } = setupCaptureTest() + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Hi')] }), + new Message({ role: 'assistant', content: [new TextBlock('Hello!')] }), + new Message({ role: 'user', content: [new TextBlock('How are you?')] }), + ] + + await collectIterator(provider.stream(messages)) + + const contents = captured.contents as Array<{ role: string; parts: Array<{ text: string }> }> + expect(contents).toHaveLength(3) + expect(contents[0]?.role).toBe('user') + expect(contents[1]?.role).toBe('model') + expect(contents[2]?.role).toBe('user') + }) + }) + + describe('content type formatting', () => { + describe('image content', () => { + it('formats image with bytes source as inlineData', () => { + const imageBlock = new ImageBlock({ + format: 'png', + source: { bytes: new Uint8Array([0x89, 0x50, 0x4e, 0x47]) }, + }) + + const contents = formatBlock(imageBlock) + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([{ inlineData: { data: 'iVBORw==', mimeType: 'image/png' } }]) + }) + + it('formats image with URL source as fileData', () => { + const imageBlock = new ImageBlock({ + format: 'jpeg', + source: { url: 'https://example.com/image.jpg' }, + }) + + const contents = formatBlock(imageBlock) + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([ + { fileData: { fileUri: 'https://example.com/image.jpg', mimeType: 'image/jpeg' } }, + ]) + }) + + it('skips image with S3 source and logs warning', () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + + const imageBlock = new ImageBlock({ + format: 'png', + source: { location: { type: 's3', uri: 's3://test/image.png' } }, + }) + + const contents = formatBlock(imageBlock) + + // Message with no valid parts is not included + expect(contents).toHaveLength(0) + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + }) + + describe('document content', () => { + it('formats document with bytes source as inlineData', () => { + const docBlock = new DocumentBlock({ + name: 'test.pdf', + format: 'pdf', + source: { bytes: new Uint8Array([0x25, 0x50, 0x44, 0x46]) }, + }) + + const contents = formatBlock(docBlock) + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([{ inlineData: { data: 'JVBERg==', mimeType: 'application/pdf' } }]) + }) + + it('formats document with text source as inlineData bytes', () => { + const docBlock = new DocumentBlock({ + name: 'test.txt', + format: 'txt', + source: { text: 'Document content here' }, + }) + + const contents = formatBlock(docBlock) + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([ + { inlineData: { data: 'RG9jdW1lbnQgY29udGVudCBoZXJl', mimeType: 'text/plain' } }, + ]) + }) + + it('formats document with content block source as separate text parts', () => { + const docBlock = new DocumentBlock({ + name: 'test.txt', + format: 'txt', + source: { content: [{ text: 'Line 1' }, { text: 'Line 2' }] }, + }) + + const contents = formatBlock(docBlock) + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([{ text: 'Line 1' }, { text: 'Line 2' }]) + }) + }) + + describe('video content', () => { + it('formats video with bytes source as inlineData', () => { + const videoBlock = new VideoBlock({ + format: 'mp4', + source: { bytes: new Uint8Array([0x00, 0x00, 0x00, 0x1c]) }, + }) + + const contents = formatBlock(videoBlock) + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([{ inlineData: { data: 'AAAAHA==', mimeType: 'video/mp4' } }]) + }) + }) + + describe('reasoning content', () => { + it('formats reasoning block with thought flag', () => { + const reasoningBlock = new ReasoningBlock({ text: 'Let me think about this...' }) + + const contents = formatBlock(reasoningBlock, 'assistant') + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([{ text: 'Let me think about this...', thought: true }]) + }) + + it('includes thought signature when present', () => { + const reasoningBlock = new ReasoningBlock({ text: 'Thinking...', signature: 'sig123' }) + + const contents = formatBlock(reasoningBlock, 'assistant') + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([{ text: 'Thinking...', thought: true, thoughtSignature: 'sig123' }]) + }) + + it('skips reasoning block with empty text', () => { + const reasoningBlock = new ReasoningBlock({ text: '' }) + + const contents = formatBlock(reasoningBlock, 'assistant') + + expect(contents).toHaveLength(0) + }) + }) + + describe('unsupported content types', () => { + it.each([ + { name: 'cache point', block: new CachePointBlock({ cacheType: 'default' }) }, + { + name: 'guard content', + block: new GuardContentBlock({ text: { qualifiers: ['guard_content'], text: 'test' } }), + }, + ])('skips $name blocks with warning', ({ block }) => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + + const contents = formatBlock(block) + + expect(contents).toHaveLength(0) + warnSpy.mockRestore() + }) + + it('formats tool use blocks as function calls', () => { + const toolUseBlock = new ToolUseBlock({ toolUseId: 'test-id', name: 'testTool', input: { key: 'value' } }) + + const contents = formatBlock(toolUseBlock, 'assistant') + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([ + { functionCall: { id: 'test-id', name: 'testTool', args: { key: 'value' } } }, + ]) + }) + }) + }) + + describe('reasoning content streaming', () => { + it('emits reasoning content delta events for thought parts', async () => { + const { provider, messages } = setupStreamTest(async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Thinking...', thought: true }] }, + }, + ], + } + yield { candidates: [{ finishReason: 'STOP' }] } + }) + + const events = await collectIterator(provider.stream(messages)) + + expect(events).toHaveLength(5) + expect(events[0]).toEqual({ type: 'modelMessageStartEvent', role: 'assistant' }) + expect(events[1]).toEqual({ type: 'modelContentBlockStartEvent' }) + expect(events[2]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: 'Thinking...' }, + }) + expect(events[3]).toEqual({ type: 'modelContentBlockStopEvent' }) + expect(events[4]).toEqual({ type: 'modelMessageStopEvent', stopReason: 'endTurn' }) + }) + + it('handles transition from reasoning to text content', async () => { + const { provider, messages } = setupStreamTest(async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Let me think...', thought: true }] }, + }, + ], + } + yield { + candidates: [ + { + content: { parts: [{ text: 'Here is my answer' }] }, + }, + ], + } + yield { candidates: [{ finishReason: 'STOP' }] } + }) + + const events = await collectIterator(provider.stream(messages)) + + // Should have: messageStart, blockStart (reasoning), delta (reasoning), blockStop, + // blockStart (text), delta (text), blockStop, messageStop + expect(events).toHaveLength(8) + + // Reasoning block + expect(events[1]).toEqual({ type: 'modelContentBlockStartEvent' }) + expect(events[2]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: 'Let me think...' }, + }) + expect(events[3]).toEqual({ type: 'modelContentBlockStopEvent' }) + + // Text block + expect(events[4]).toEqual({ type: 'modelContentBlockStartEvent' }) + expect(events[5]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Here is my answer' }, + }) + expect(events[6]).toEqual({ type: 'modelContentBlockStopEvent' }) + expect(events[7]).toEqual({ type: 'modelMessageStopEvent', stopReason: 'endTurn' }) + }) + + it('includes signature in reasoning delta when present', async () => { + const { provider, messages } = setupStreamTest(async function* () { + yield { + candidates: [ + { + content: { + parts: [ + { + text: 'Thinking...', + thought: true, + thoughtSignature: 'sig456', + }, + ], + }, + }, + ], + } + yield { candidates: [{ finishReason: 'STOP' }] } + }) + + const events = await collectIterator(provider.stream(messages)) + + const deltaEvent = events.find( + (e) => e.type === 'modelContentBlockDeltaEvent' && e.delta.type === 'reasoningContentDelta' + ) + expect(deltaEvent).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: 'Thinking...', signature: 'sig456' }, + }) + }) + }) + + describe('tool configuration', () => { + it('passes tool specs as functionDeclarations', async () => { + const { provider, captured, messages } = setupCaptureTest() + + await collectIterator( + provider.stream(messages, { + toolSpecs: [ + { + name: 'get_weather', + description: 'Get the weather', + inputSchema: { type: 'object', properties: { city: { type: 'string' } } }, + }, + ], + }) + ) + + const config = captured.config as { tools?: unknown[] } + expect(config.tools).toEqual([ + { + functionDeclarations: [ + { + name: 'get_weather', + description: 'Get the weather', + parametersJsonSchema: { type: 'object', properties: { city: { type: 'string' } } }, + }, + ], + }, + ]) + }) + + it.each([ + { + name: 'auto to AUTO', + toolChoice: { auto: {} }, + expectedMode: FunctionCallingConfigMode.AUTO, + }, + { + name: 'any to ANY', + toolChoice: { any: {} }, + expectedMode: FunctionCallingConfigMode.ANY, + }, + { + name: 'tool to ANY with allowedFunctionNames', + toolChoice: { tool: { name: 'get_weather' } }, + expectedMode: FunctionCallingConfigMode.ANY, + expectedAllowedFunctionNames: ['get_weather'], + }, + ])('maps toolChoice $name', async ({ toolChoice, expectedMode, expectedAllowedFunctionNames }) => { + const { provider, captured, messages } = setupCaptureTest() + + await collectIterator( + provider.stream(messages, { + toolSpecs: [{ name: 'get_weather', description: 'test' }], + toolChoice, + }) + ) + + const config = captured.config as { + toolConfig?: { functionCallingConfig?: { mode?: string; allowedFunctionNames?: string[] } } + } + expect(config.toolConfig?.functionCallingConfig?.mode).toBe(expectedMode) + if (expectedAllowedFunctionNames) { + expect(config.toolConfig?.functionCallingConfig?.allowedFunctionNames).toEqual(expectedAllowedFunctionNames) + } + }) + + it('does not add tools config when no toolSpecs provided', async () => { + const { provider, captured, messages } = setupCaptureTest() + + await collectIterator(provider.stream(messages)) + + const config = captured.config as { tools?: unknown } + expect(config.tools).toBeUndefined() + }) + }) + + describe('built-in tools', () => { + it('appends builtInTools to config.tools alongside functionDeclarations', async () => { + const { client, captured } = createMockClientWithCapture() + const provider = new GoogleModel({ client, builtInTools: [{ googleSearch: {} }] }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await collectIterator( + provider.stream(messages, { + toolSpecs: [ + { + name: 'get_weather', + description: 'Get the weather', + inputSchema: { type: 'object', properties: { city: { type: 'string' } } }, + }, + ], + }) + ) + + const config = captured.config as { tools?: unknown[] } + expect(config.tools).toHaveLength(2) + expect(config.tools![0]).toEqual({ + functionDeclarations: [ + { + name: 'get_weather', + description: 'Get the weather', + parametersJsonSchema: { type: 'object', properties: { city: { type: 'string' } } }, + }, + ], + }) + expect(config.tools![1]).toEqual({ googleSearch: {} }) + }) + + it('passes builtInTools when no toolSpecs provided', async () => { + const { client, captured } = createMockClientWithCapture() + const provider = new GoogleModel({ client, builtInTools: [{ codeExecution: {} }] }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await collectIterator(provider.stream(messages)) + + const config = captured.config as { tools?: unknown[] } + expect(config.tools).toHaveLength(1) + expect(config.tools![0]).toEqual({ codeExecution: {} }) + }) + + it('does not add tools when neither builtInTools nor toolSpecs provided', async () => { + const { client, captured } = createMockClientWithCapture() + const provider = new GoogleModel({ client }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await collectIterator(provider.stream(messages)) + + const config = captured.config as { tools?: unknown } + expect(config.tools).toBeUndefined() + }) + }) + + describe('tool use formatting', () => { + it('formats toolUseBlock with reasoningSignature as thoughtSignature', () => { + const toolUseBlock = new ToolUseBlock({ + toolUseId: 'test-id', + name: 'testTool', + input: { key: 'value' }, + reasoningSignature: 'sig789', + }) + + const contents = formatBlock(toolUseBlock, 'assistant') + + expect(contents).toHaveLength(1) + expect(contents[0]!.parts).toEqual([ + { + functionCall: { id: 'test-id', name: 'testTool', args: { key: 'value' } }, + thoughtSignature: 'sig789', + }, + ]) + }) + + it('formats toolResultBlock as functionResponse', () => { + const toolUseBlock = new ToolUseBlock({ toolUseId: 'test-id', name: 'testTool', input: {} }) + const toolResultBlock = new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('result text')], + }) + const messages = [ + new Message({ role: 'assistant', content: [toolUseBlock] }), + new Message({ role: 'user', content: [toolResultBlock] }), + ] + + const contents = formatMessages(messages) + + expect(contents).toHaveLength(2) + expect(contents[1]!.parts![0]).toEqual({ + functionResponse: { + id: 'test-id', + name: 'testTool', + response: { output: [{ text: 'result text' }] }, + }, + }) + }) + + it('resolves tool name from toolUseId in toolResultBlock', () => { + const toolUseBlock = new ToolUseBlock({ toolUseId: 'abc-123', name: 'my_tool', input: {} }) + const toolResultBlock = new ToolResultBlock({ + toolUseId: 'abc-123', + status: 'success', + content: [new TextBlock('ok')], + }) + const messages = [ + new Message({ role: 'assistant', content: [toolUseBlock] }), + new Message({ role: 'user', content: [toolResultBlock] }), + ] + + const contents = formatMessages(messages) + + const resultPart = contents[1]!.parts![0]! + const fr = (resultPart as { functionResponse: { name: string } }).functionResponse + expect(fr.name).toBe('my_tool') + }) + + it('falls back to toolUseId when tool name mapping is not found', () => { + const toolResultBlock = new ToolResultBlock({ + toolUseId: 'unknown-id', + status: 'success', + content: [new TextBlock('ok')], + }) + const messages = [new Message({ role: 'user', content: [toolResultBlock] })] + + const contents = formatMessages(messages) + + const resultPart = contents[0]!.parts![0]! + const fr = (resultPart as { functionResponse: { name: string } }).functionResponse + expect(fr.name).toBe('unknown-id') + }) + + it('formats image block in tool result as inlineData', () => { + const imageBytes = new Uint8Array([0x89, 0x50, 0x4e, 0x47]) + const toolUseBlock = new ToolUseBlock({ toolUseId: 'test-id', name: 'screenshot', input: {} }) + const toolResultBlock = new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new ImageBlock({ format: 'png', source: { bytes: imageBytes } })], + }) + const messages = [ + new Message({ role: 'assistant', content: [toolUseBlock] }), + new Message({ role: 'user', content: [toolResultBlock] }), + ] + + const contents = formatMessages(messages) + + const resultPart = contents[1]!.parts![0]! as { functionResponse: { response: unknown; parts?: unknown[] } } + // Image goes to separate parts, not into response.output + expect(resultPart.functionResponse.response).toEqual({ output: [] }) + expect(resultPart.functionResponse.parts).toEqual([ + { inlineData: { data: 'iVBORw==', mimeType: 'image/png', displayName: 'image.png' } }, + ]) + }) + + it('formats document block with bytes source in tool result as inlineData', () => { + const docBytes = new Uint8Array([0x25, 0x50, 0x44, 0x46]) + const toolUseBlock = new ToolUseBlock({ toolUseId: 'test-id', name: 'read_doc', input: {} }) + const toolResultBlock = new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new DocumentBlock({ name: 'report.pdf', format: 'pdf', source: { bytes: docBytes } })], + }) + const messages = [ + new Message({ role: 'assistant', content: [toolUseBlock] }), + new Message({ role: 'user', content: [toolResultBlock] }), + ] + + const contents = formatMessages(messages) + + const resultPart = contents[1]!.parts![0]! as { functionResponse: { response: unknown; parts?: unknown[] } } + expect(resultPart.functionResponse.response).toEqual({ output: [] }) + expect(resultPart.functionResponse.parts).toEqual([ + { inlineData: { data: 'JVBERg==', mimeType: 'application/pdf', displayName: 'report.pdf' } }, + ]) + }) + + it('formats document block with text source in tool result as inlineData', () => { + const toolUseBlock = new ToolUseBlock({ toolUseId: 'test-id', name: 'read_doc', input: {} }) + const toolResultBlock = new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new DocumentBlock({ name: 'notes.txt', format: 'txt', source: { text: 'Hello' } })], + }) + const messages = [ + new Message({ role: 'assistant', content: [toolUseBlock] }), + new Message({ role: 'user', content: [toolResultBlock] }), + ] + + const contents = formatMessages(messages) + + const resultPart = contents[1]!.parts![0]! as { functionResponse: { response: unknown; parts?: unknown[] } } + expect(resultPart.functionResponse.response).toEqual({ output: [] }) + expect(resultPart.functionResponse.parts).toEqual([ + { inlineData: { data: 'SGVsbG8=', mimeType: 'text/plain', displayName: 'notes.txt' } }, + ]) + }) + + it('skips video block in tool result with warning', () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const toolUseBlock = new ToolUseBlock({ toolUseId: 'test-id', name: 'capture', input: {} }) + const toolResultBlock = new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [new TextBlock('captured'), new VideoBlock({ format: 'mp4', source: { bytes: new Uint8Array([1]) } })], + }) + const messages = [ + new Message({ role: 'assistant', content: [toolUseBlock] }), + new Message({ role: 'user', content: [toolResultBlock] }), + ] + + const contents = formatMessages(messages) + + const resultPart = contents[1]!.parts![0]! as { functionResponse: { response: unknown; parts?: unknown[] } } + expect(resultPart.functionResponse.response).toEqual({ output: [{ text: 'captured' }] }) + // No parts for video - it's skipped + expect(resultPart.functionResponse.parts).toBeUndefined() + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + + it('formats mixed text and image content in tool result', () => { + const imageBytes = new Uint8Array([1, 2]) + const toolUseBlock = new ToolUseBlock({ toolUseId: 'test-id', name: 'analyze', input: {} }) + const toolResultBlock = new ToolResultBlock({ + toolUseId: 'test-id', + status: 'success', + content: [ + new TextBlock('Analysis complete'), + new ImageBlock({ format: 'jpeg', source: { bytes: imageBytes } }), + ], + }) + const messages = [ + new Message({ role: 'assistant', content: [toolUseBlock] }), + new Message({ role: 'user', content: [toolResultBlock] }), + ] + + const contents = formatMessages(messages) + + const resultPart = contents[1]!.parts![0]! as { functionResponse: { response: unknown; parts?: unknown[] } } + expect(resultPart.functionResponse.response).toEqual({ output: [{ text: 'Analysis complete' }] }) + expect(resultPart.functionResponse.parts).toEqual([ + { inlineData: { data: 'AQI=', mimeType: 'image/jpeg', displayName: 'image.jpeg' } }, + ]) + }) + }) + + describe('tool use streaming', () => { + function createStreamState(): GoogleStreamState { + return { + messageStarted: true, + textContentBlockStarted: false, + reasoningContentBlockStarted: false, + hasToolCalls: false, + inputTokens: 0, + outputTokens: 0, + } + } + + it('emits tool use events for function call in response', () => { + const streamState = createStreamState() + const chunk = { + candidates: [ + { + content: { + parts: [{ functionCall: { id: 'tool-1', name: 'get_weather', args: { city: 'NYC' } } }], + }, + }, + ], + } + + const events = mapChunkToEvents(chunk as unknown as GenerateContentResponse, streamState) + + expect(events).toEqual([ + { + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'get_weather', toolUseId: 'tool-1' }, + }, + { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"city":"NYC"}' }, + }, + { type: 'modelContentBlockStopEvent' }, + ]) + expect(streamState.hasToolCalls).toBe(true) + }) + + it('generates tool use ID when Gemini does not provide one', () => { + const streamState = createStreamState() + const chunk = { + candidates: [ + { + content: { + parts: [{ functionCall: { name: 'testTool', args: {} } }], + }, + }, + ], + } + + const events = mapChunkToEvents(chunk as unknown as GenerateContentResponse, streamState) + + const startEvent = events[0]! + expect(startEvent.type).toBe('modelContentBlockStartEvent') + const start = (startEvent as { start: { toolUseId: string } }).start + expect(start.toolUseId).toMatch(/^tooluse_/) + }) + + it('includes reasoningSignature from thoughtSignature on function call', () => { + const streamState = createStreamState() + const chunk = { + candidates: [ + { + content: { + parts: [ + { + functionCall: { id: 'tool-1', name: 'testTool', args: {} }, + thoughtSignature: 'sig-abc', + }, + ], + }, + }, + ], + } + + const events = mapChunkToEvents(chunk as unknown as GenerateContentResponse, streamState) + + const startEvent = events[0]! + const start = (startEvent as { start: { reasoningSignature: string } }).start + expect(start.reasoningSignature).toBe('sig-abc') + }) + + it('sets stop reason to toolUse when function calls are present', () => { + const streamState = createStreamState() + streamState.hasToolCalls = true + + const chunk = { + candidates: [{ finishReason: 'STOP' }], + } + + const events = mapChunkToEvents(chunk as unknown as GenerateContentResponse, streamState) + + expect(events).toEqual([{ type: 'modelMessageStopEvent', stopReason: 'toolUse' }]) + }) + + it.each([ + { blockType: 'reasoning', stateField: 'reasoningContentBlockStarted' as const }, + { blockType: 'text', stateField: 'textContentBlockStarted' as const }, + ])('closes $blockType block before tool use block', ({ stateField }) => { + const streamState = createStreamState() + streamState[stateField] = true + + const chunk = { + candidates: [ + { + content: { + parts: [{ functionCall: { id: 'tool-1', name: 'testTool', args: {} } }], + }, + }, + ], + } + + const events = mapChunkToEvents(chunk as unknown as GenerateContentResponse, streamState) + + expect(events[0]).toEqual({ type: 'modelContentBlockStopEvent' }) + expect(events[1]).toEqual({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'testTool', toolUseId: 'tool-1' }, + }) + expect(streamState[stateField]).toBe(false) + }) + + it('handles multiple function calls in a single response', () => { + const streamState = createStreamState() + const chunk = { + candidates: [ + { + content: { + parts: [ + { functionCall: { id: 'tool-1', name: 'get_weather', args: { city: 'NYC' } } }, + { functionCall: { id: 'tool-2', name: 'get_time', args: { tz: 'EST' } } }, + ], + }, + }, + ], + } + + const events = mapChunkToEvents(chunk as unknown as GenerateContentResponse, streamState) + + // Each function call: start + delta + stop = 3 events, x2 = 6 + expect(events).toHaveLength(6) + expect(events[0]).toEqual({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'get_weather', toolUseId: 'tool-1' }, + }) + expect(events[3]).toEqual({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'get_time', toolUseId: 'tool-2' }, + }) + }) + + it('handles full tool use flow via stream method', async () => { + const { provider, messages } = setupStreamTest(async function* () { + yield { + candidates: [ + { + content: { + parts: [{ functionCall: { id: 'call-1', name: 'get_weather', args: { city: 'NYC' } } }], + }, + }, + ], + } + yield { candidates: [{ finishReason: 'STOP' }] } + }) + + const events = await collectIterator(provider.stream(messages)) + + // messageStart, blockStart (toolUse), delta (toolUseInput), blockStop, messageStop + expect(events).toHaveLength(5) + expect(events[0]).toEqual({ type: 'modelMessageStartEvent', role: 'assistant' }) + expect(events[1]).toEqual({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'get_weather', toolUseId: 'call-1' }, + }) + expect(events[2]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"city":"NYC"}' }, + }) + expect(events[3]).toEqual({ type: 'modelContentBlockStopEvent' }) + expect(events[4]).toEqual({ type: 'modelMessageStopEvent', stopReason: 'toolUse' }) + }) + }) + + describe('countTokens', () => { + const messages: Message[] = [new Message({ role: 'user', content: [new TextBlock('hello')] })] + const toolSpecs = [ + { name: 'test_tool', description: 'A test tool', inputSchema: { type: 'object' as const, properties: {} } }, + ] + + function createCountTokensClient(mockCountTokens: ReturnType): GoogleGenAI { + return { + models: { + generateContentStream: vi.fn(), + countTokens: mockCountTokens, + }, + } as unknown as GoogleGenAI + } + + it('should use heuristic by default when useNativeTokenCount is not set', async () => { + const mockCountTokens = vi.fn() + const client = createCountTokensClient(mockCountTokens) + const model = new GoogleModel({ client, modelId: 'gemini-2.5-flash' }) + + const result = await model.countTokens(messages) + + expect(mockCountTokens).not.toHaveBeenCalled() + expect(result).toBe(2) // heuristic: Math.ceil('hello'.length / 4) + }) + + it('should return native token count on success', async () => { + const mockCountTokens = vi.fn(async () => ({ totalTokens: 42 })) + const client = createCountTokensClient(mockCountTokens) + const model = new GoogleModel({ client, modelId: 'gemini-2.5-flash', useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(result).toBe(42) + expect(mockCountTokens).toHaveBeenCalledOnce() + }) + + it('should add heuristic estimate for system prompt', async () => { + const mockCountTokens = vi.fn(async () => ({ totalTokens: 55 })) + const client = createCountTokensClient(mockCountTokens) + const model = new GoogleModel({ client, modelId: 'gemini-2.5-flash', useNativeTokenCount: true }) + + const result = await model.countTokens(messages, { systemPrompt: 'Be helpful.' }) + + expect(result).toBeGreaterThan(55) // native (55) + heuristic for system prompt + }) + + it('should add heuristic estimate for tool specs', async () => { + const mockCountTokens = vi.fn(async () => ({ totalTokens: 100 })) + const client = createCountTokensClient(mockCountTokens) + const model = new GoogleModel({ client, modelId: 'gemini-2.5-flash', useNativeTokenCount: true }) + + const result = await model.countTokens(messages, { toolSpecs }) + + expect(result).toBeGreaterThan(100) // native (100) + heuristic for tools + }) + + it('should fall back on null totalTokens', async () => { + const mockCountTokens = vi.fn(async () => ({ totalTokens: null })) + const client = createCountTokensClient(mockCountTokens) + const model = new GoogleModel({ client, modelId: 'gemini-2.5-flash', useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(typeof result).toBe('number') + expect(result).toBeGreaterThanOrEqual(0) + }) + + it('should fall back to estimation on API error', async () => { + const mockCountTokens = vi.fn(async () => { + throw new Error('Unsupported') + }) + const client = createCountTokensClient(mockCountTokens) + const model = new GoogleModel({ client, modelId: 'gemini-2.5-flash', useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(typeof result).toBe('number') + expect(result).toBeGreaterThanOrEqual(0) + }) + + it('should fall back to estimation on generic exception', async () => { + const mockCountTokens = vi.fn(async () => { + throw new Error('Connection failed') + }) + const client = createCountTokensClient(mockCountTokens) + const model = new GoogleModel({ client, modelId: 'gemini-2.5-flash', useNativeTokenCount: true }) + + const result = await model.countTokens(messages) + + expect(typeof result).toBe('number') + expect(result).toBeGreaterThanOrEqual(0) + }) + + it('should skip native API and use heuristic when useNativeTokenCount is false', async () => { + const mockCountTokens = vi.fn() + const client = createCountTokensClient(mockCountTokens) + const model = new GoogleModel({ client, modelId: 'gemini-2.5-flash', useNativeTokenCount: false }) + + const result = await model.countTokens(messages) + + expect(mockCountTokens).not.toHaveBeenCalled() + expect(result).toBe(2) // heuristic: Math.ceil('hello'.length / 4) + }) + }) +}) diff --git a/strands-ts/src/models/__tests__/model.test.ts b/strands-ts/src/models/__tests__/model.test.ts new file mode 100644 index 0000000000..3530e68df6 --- /dev/null +++ b/strands-ts/src/models/__tests__/model.test.ts @@ -0,0 +1,1016 @@ +import { describe, it, expect } from 'vitest' +import { + Message, + TextBlock, + ToolUseBlock, + ToolResultBlock, + ReasoningBlock, + GuardContentBlock, +} from '../../types/messages.js' +import { CitationsBlock } from '../../types/citations.js' +import { TestModelProvider, collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import { MaxTokensError, ModelError } from '../../errors.js' +import { Model } from '../model.js' +import type { BaseModelConfig, StreamOptions } from '../model.js' +import type { ModelStreamEvent } from '../streaming.js' + +/** + * Test model provider that throws an error from stream(). + */ +class ErrorThrowingModelProvider extends Model { + private config: BaseModelConfig = { modelId: 'test-model' } + private errorToThrow: Error + + constructor(errorToThrow: Error) { + super() + this.errorToThrow = errorToThrow + } + + updateConfig(modelConfig: BaseModelConfig): void { + this.config = { ...this.config, ...modelConfig } + } + + getConfig(): BaseModelConfig { + return this.config + } + + // eslint-disable-next-line require-yield + async *stream(_messages: Message[], _options?: StreamOptions): AsyncGenerator { + throw this.errorToThrow + } +} + +describe('Model', () => { + describe('streamAggregated', () => { + describe('when streaming a simple text message', () => { + it('yields original events plus aggregated content block and returns final message', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + // Verify all yielded items (events + aggregated content block + metadata) + expect(items).toEqual([ + { type: 'modelMessageStartEvent', role: 'assistant' }, + { type: 'modelContentBlockStartEvent' }, + { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + }, + { type: 'modelContentBlockStopEvent' }, + { type: 'textBlock', text: 'Hello' }, + { type: 'modelMessageStopEvent', stopReason: 'endTurn' }, + { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }, + ]) + + // Verify the returned result includes metadata + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [{ type: 'textBlock', text: 'Hello' }], + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }, + }, + stopReason: 'endTurn', + metadata: { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }, + }) + }) + + it('throws MaxTokenError when stopReason is MaxTokenError', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'maxTokens' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => await collectGenerator(provider.streamAggregated(messages))).rejects.toThrow( + 'Model reached maximum token limit. This is an unrecoverable state that requires intervention.' + ) + }) + }) + + describe('when streaming multiple text blocks', () => { + it('yields all blocks in order', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'First' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Second' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 10, totalTokens: 20 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + expect(items).toContainEqual({ type: 'textBlock', text: 'First' }) + expect(items).toContainEqual({ type: 'textBlock', text: 'Second' }) + expect(items).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 10, totalTokens: 20 }, + }) + + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [ + { type: 'textBlock', text: 'First' }, + { type: 'textBlock', text: 'Second' }, + ], + metadata: { + usage: { inputTokens: 10, outputTokens: 10, totalTokens: 20 }, + }, + }, + stopReason: 'endTurn', + metadata: { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 10, totalTokens: 20 }, + }, + }) + }) + }) + + describe('when streaming tool use', () => { + it('yields complete tool use block', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', toolUseId: 'tool1', name: 'get_weather' }, + } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"location"' }, + } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: ': "Paris"}' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'toolUse' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 8, totalTokens: 18 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + expect(items).toContainEqual({ + type: 'toolUseBlock', + toolUseId: 'tool1', + name: 'get_weather', + input: { location: 'Paris' }, + }) + expect(items).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 8, totalTokens: 18 }, + }) + + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [ + { + type: 'toolUseBlock', + toolUseId: 'tool1', + name: 'get_weather', + input: { location: 'Paris' }, + }, + ], + metadata: { + usage: { inputTokens: 10, outputTokens: 8, totalTokens: 18 }, + }, + }, + stopReason: 'toolUse', + metadata: { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 8, totalTokens: 18 }, + }, + }) + }) + + it('yields complete tool use block with empty input', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', toolUseId: 'tool1', name: 'get_time' }, + } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'toolUse' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 8, totalTokens: 18 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + expect(items).toContainEqual({ + type: 'toolUseBlock', + toolUseId: 'tool1', + name: 'get_time', + input: {}, + }) + expect(items).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 8, totalTokens: 18 }, + }) + + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [ + { + type: 'toolUseBlock', + toolUseId: 'tool1', + name: 'get_time', + input: {}, + }, + ], + metadata: { + usage: { inputTokens: 10, outputTokens: 8, totalTokens: 18 }, + }, + }, + stopReason: 'toolUse', + metadata: { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 8, totalTokens: 18 }, + }, + }) + }) + + it('throws MaxTokenError when stopReason is MaxTokenError and toolUse is partial', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', toolUseId: 'tool1', name: 'get_weather' }, + } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"location"' }, + } + yield { type: 'modelMessageStopEvent', stopReason: 'maxTokens' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 8, totalTokens: 18 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => await collectGenerator(provider.streamAggregated(messages))).rejects.toThrow( + MaxTokensError + ) + }) + + it('preserves SyntaxError instead of overwriting with MaxTokensError when tool input JSON is malformed', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', toolUseId: 'tool1', name: 'get_weather' }, + } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{invalid json' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'maxTokens' } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + try { + await collectGenerator(provider.streamAggregated(messages)) + expect.fail('Expected error to be thrown') + } catch (error) { + expect(error).toBeInstanceOf(ModelError) + expect(error).not.toBeInstanceOf(MaxTokensError) + expect((error as ModelError).cause).toBeInstanceOf(SyntaxError) + } + }) + }) + + describe('when streaming reasoning content', () => { + it('yields complete reasoning block', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: 'Thinking about', signature: 'sig1' }, + } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: ' the problem' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 10, totalTokens: 20 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + expect(items).toContainEqual({ + type: 'reasoningBlock', + text: 'Thinking about the problem', + signature: 'sig1', + }) + expect(items).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 10, totalTokens: 20 }, + }) + + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [ + { + type: 'reasoningBlock', + text: 'Thinking about the problem', + signature: 'sig1', + }, + ], + metadata: { + usage: { inputTokens: 10, outputTokens: 10, totalTokens: 20 }, + }, + }, + stopReason: 'endTurn', + metadata: { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 10, totalTokens: 20 }, + }, + }) + }) + + it('yields redacted content reasoning block', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', redactedContent: new Uint8Array(0) }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + expect(items).toContainEqual({ + type: 'reasoningBlock', + redactedContent: new Uint8Array(0), + }) + expect(items).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }) + + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [ + { + type: 'reasoningBlock', + redactedContent: new Uint8Array(0), + }, + ], + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }, + }, + stopReason: 'endTurn', + metadata: { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }, + }) + }) + + it('omits signature if not present', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: 'Thinking' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + expect(items).toContainEqual({ + type: 'reasoningBlock', + text: 'Thinking', + }) + expect(items).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }) + + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [ + { + type: 'reasoningBlock', + text: 'Thinking', + }, + ], + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }, + }, + stopReason: 'endTurn', + metadata: { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }, + }) + }) + }) + + describe('when streaming mixed content blocks', () => { + it('yields all blocks in correct order', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', toolUseId: 'tool1', name: 'get_weather' }, + } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"city": "Paris"}' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: 'Reasoning', signature: 'sig1' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 15, totalTokens: 25 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + expect(items).toContainEqual({ type: 'textBlock', text: 'Hello' }) + expect(items).toContainEqual({ + type: 'toolUseBlock', + toolUseId: 'tool1', + name: 'get_weather', + input: { city: 'Paris' }, + }) + expect(items).toContainEqual({ type: 'reasoningBlock', text: 'Reasoning', signature: 'sig1' }) + expect(items).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 15, totalTokens: 25 }, + }) + + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [ + { type: 'textBlock', text: 'Hello' }, + { type: 'toolUseBlock', toolUseId: 'tool1', name: 'get_weather', input: { city: 'Paris' } }, + { type: 'reasoningBlock', text: 'Reasoning', signature: 'sig1' }, + ], + metadata: { + usage: { inputTokens: 10, outputTokens: 15, totalTokens: 25 }, + }, + }, + stopReason: 'endTurn', + metadata: { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 15, totalTokens: 25 }, + }, + }) + }) + }) + + describe('when multiple metadata events are emitted', () => { + it('yields all metadata events but keeps only the last one in return value', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 }, + metrics: { latencyMs: 100 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + // Both metadata events should be yielded + expect(items).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }) + expect(items).toContainEqual({ + type: 'modelMetadataEvent', + usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 }, + metrics: { latencyMs: 100 }, + }) + + // Only the last metadata should be in return value + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [{ type: 'textBlock', text: 'Hello' }], + metadata: { + usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 }, + metrics: { latencyMs: 100 }, + }, + }, + stopReason: 'endTurn', + metadata: { + type: 'modelMetadataEvent', + usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 }, + metrics: { latencyMs: 100 }, + }, + }) + }) + }) + + describe('when no metadata events are emitted', () => { + it('returns result with undefined metadata', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const { items, result } = await collectGenerator(provider.streamAggregated(messages)) + + // No metadata event should be in yielded items + expect(items.filter((item) => item.type === 'modelMetadataEvent')).toHaveLength(0) + + // Metadata should be undefined in return value + expect(result).toEqual({ + message: { + type: 'message', + role: 'assistant', + content: [{ type: 'textBlock', text: 'Hello' }], + }, + stopReason: 'endTurn', + metadata: undefined, + }) + }) + }) + + describe('when stream() throws an error', () => { + it('wraps non-ModelError errors in ModelError with original as cause', async () => { + const originalError = new Error('API connection failed') + const provider = new ErrorThrowingModelProvider(originalError) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + try { + await collectGenerator(provider.streamAggregated(messages)) + expect.fail('Expected error to be thrown') + } catch (error) { + expect(error).toBeInstanceOf(ModelError) + expect((error as ModelError).message).toBe('API connection failed') + expect((error as ModelError).cause).toBe(originalError) + } + }) + }) + + describe('when receiving redact content events', () => { + it('returns redaction.userMessage when inputRedaction is present', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'guardrailIntervened' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + yield { + type: 'modelRedactionEvent', + inputRedaction: { replaceContent: '[User input redacted.]' }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Sensitive content')] })] + + const { result } = await collectGenerator(provider.streamAggregated(messages)) + + // Verify redaction.userMessage is returned for agent to handle + expect(result.redaction?.userMessage).toBe('[User input redacted.]') + + // Messages array should NOT be modified (agent handles this) + expect(messages[0]!.content).toEqual([{ type: 'textBlock', text: 'Sensitive content' }]) + }) + + it('redacts assistant message directly when outputRedaction is present', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Harmful content' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'guardrailIntervened' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + yield { + type: 'modelRedactionEvent', + outputRedaction: { replaceContent: '[Assistant output redacted.]' }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Tell me something')] })] + + const { result } = await collectGenerator(provider.streamAggregated(messages)) + + // Assistant message is redacted directly by the model + expect(result.message.role).toBe('assistant') + expect(result.message.content).toEqual([{ type: 'textBlock', text: '[Assistant output redacted.]' }]) + + // No redaction.userMessage since assistant redaction is handled directly + expect(result.redaction?.userMessage).toBeUndefined() + }) + + it('returns redactionMessage and redacts assistant when both are present', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Response' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'guardrailIntervened' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + yield { + type: 'modelRedactionEvent', + inputRedaction: { replaceContent: '[User input redacted.]' }, + } + yield { + type: 'modelRedactionEvent', + outputRedaction: { replaceContent: '[Assistant output redacted.]' }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Input')] })] + + const { result } = await collectGenerator(provider.streamAggregated(messages)) + + // Verify redaction.userMessage is returned for agent to handle user redaction + expect(result.redaction?.userMessage).toBe('[User input redacted.]') + + // Assistant message is redacted directly + expect(result.message.role).toBe('assistant') + expect(result.message.content).toEqual([{ type: 'textBlock', text: '[Assistant output redacted.]' }]) + }) + + it('does not include redaction when no redact events are received', async () => { + const provider = new TestModelProvider(async function* () { + yield { type: 'modelMessageStartEvent', role: 'assistant' } + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + } + yield { type: 'modelContentBlockStopEvent' } + yield { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + yield { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + const { result } = await collectGenerator(provider.streamAggregated(messages)) + + // Verify redaction.userMessage is undefined + expect(result.redaction?.userMessage).toBeUndefined() + }) + }) + }) +}) + +describe('Model.modelId', () => { + it('returns modelId from model config', () => { + const provider = new TestModelProvider() + provider.updateConfig({ modelId: 'my-model' }) + + expect(provider.modelId).toBe('my-model') + }) +}) + +describe('countTokens', () => { + it('estimates text block tokens using chars/4 heuristic', async () => { + const provider = new TestModelProvider() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello world')] })] + + const result = await provider.countTokens(messages) + + expect(result).toBe(3) + }) + + it('estimates toolUse block tokens (name + JSON input)', async () => { + const provider = new TestModelProvider() + const messages = [ + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'get_weather', toolUseId: 'id1', input: { city: 'Seattle' } })], + }), + ] + + const result = await provider.countTokens(messages) + + expect(result).toBe(3 + 9) + }) + + it('estimates toolResult block tokens (text items only)', async () => { + const provider = new TestModelProvider() + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'id1', + status: 'success', + content: [new TextBlock('72°F and sunny')], + }), + ], + }), + ] + + const result = await provider.countTokens(messages) + + expect(result).toBe(Math.ceil('72°F and sunny'.length / 4)) + }) + + it('estimates reasoning block tokens', async () => { + const provider = new TestModelProvider() + const messages = [ + new Message({ + role: 'assistant', + content: [new ReasoningBlock({ text: 'Let me think about this step by step' })], + }), + ] + + const result = await provider.countTokens(messages) + + expect(result).toBe(Math.ceil('Let me think about this step by step'.length / 4)) + }) + + it('estimates guardContent block tokens', async () => { + const provider = new TestModelProvider() + const messages = [ + new Message({ + role: 'user', + content: [ + new GuardContentBlock({ + text: { qualifiers: ['query'], text: 'Is this safe?' }, + }), + ], + }), + ] + + const result = await provider.countTokens(messages) + + expect(result).toBe(Math.ceil('Is this safe?'.length / 4)) + }) + + it('estimates citations block tokens', async () => { + const provider = new TestModelProvider() + const messages = [ + new Message({ + role: 'assistant', + content: [ + new CitationsBlock({ + citations: [], + content: [{ text: 'cited text here' }], + }), + ], + }), + ] + + const result = await provider.countTokens(messages) + + expect(result).toBe(Math.ceil('cited text here'.length / 4)) + }) + + it('estimates string system prompt tokens', async () => { + const provider = new TestModelProvider() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const result = await provider.countTokens(messages, { + systemPrompt: 'You are a helpful assistant', + }) + + expect(result).toBe(Math.ceil('You are a helpful assistant'.length / 4) + Math.ceil('Hi'.length / 4)) + }) + + it('estimates array system prompt tokens', async () => { + const provider = new TestModelProvider() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const result = await provider.countTokens(messages, { + systemPrompt: [new TextBlock('System instructions')], + }) + + expect(result).toBe(Math.ceil('System instructions'.length / 4) + Math.ceil('Hi'.length / 4)) + }) + + it('estimates tool spec tokens', async () => { + const provider = new TestModelProvider() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + const toolSpecs = [{ name: 'get_weather', description: 'Get weather for a city' }] + + const result = await provider.countTokens(messages, { toolSpecs }) + + const specJson = JSON.stringify(toolSpecs[0]) + expect(result).toBe(Math.ceil('Hi'.length / 4) + Math.ceil(specJson.length / 2)) + }) + + it('returns 0 for empty messages', async () => { + const provider = new TestModelProvider() + + const result = await provider.countTokens([]) + + expect(result).toBe(0) + }) + + it('skips reasoning blocks without text', async () => { + const provider = new TestModelProvider() + const messages = [ + new Message({ + role: 'assistant', + content: [new ReasoningBlock({ signature: 'sig123' })], + }), + ] + + const result = await provider.countTokens(messages) + + expect(result).toBe(0) + }) + + it('estimates guardContent in array system prompt', async () => { + const provider = new TestModelProvider() + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const result = await provider.countTokens(messages, { + systemPrompt: [new GuardContentBlock({ text: { qualifiers: ['query'], text: 'Guard text here' } })], + }) + + expect(result).toBe(Math.ceil('Guard text here'.length / 4) + Math.ceil('Hi'.length / 4)) + }) + + it('accumulates tokens across multiple messages with mixed content', async () => { + const provider = new TestModelProvider() + const messages = [ + new Message({ role: 'user', content: [new TextBlock('What is the weather?')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'get_weather', toolUseId: 'id1', input: { city: 'Seattle' } })], + }), + new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId: 'id1', status: 'success', content: [new TextBlock('72F')] })], + }), + ] + + const result = await provider.countTokens(messages, { systemPrompt: 'You are helpful' }) + + const expected = + Math.ceil('You are helpful'.length / 4) + + Math.ceil('What is the weather?'.length / 4) + + Math.ceil('get_weather'.length / 4) + + Math.ceil(JSON.stringify({ city: 'Seattle' }).length / 2) + + Math.ceil('72F'.length / 4) + expect(result).toBe(expected) + }) +}) diff --git a/strands-ts/src/models/__tests__/streaming.test.ts b/strands-ts/src/models/__tests__/streaming.test.ts new file mode 100644 index 0000000000..23d445a220 --- /dev/null +++ b/strands-ts/src/models/__tests__/streaming.test.ts @@ -0,0 +1,59 @@ +import { describe, it, expect } from 'vitest' +import { isModelStreamEvent } from '../streaming.js' +import type { ModelStreamEvent } from '../streaming.js' + +describe('isModelStreamEvent', () => { + it('returns true for modelMessageStartEvent', () => { + const event: ModelStreamEvent = { type: 'modelMessageStartEvent', role: 'assistant' } + expect(isModelStreamEvent(event)).toBe(true) + }) + + it('returns true for modelContentBlockStartEvent', () => { + const event: ModelStreamEvent = { type: 'modelContentBlockStartEvent' } + expect(isModelStreamEvent(event)).toBe(true) + }) + + it('returns true for modelContentBlockDeltaEvent', () => { + const event: ModelStreamEvent = { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'hello' }, + } + expect(isModelStreamEvent(event)).toBe(true) + }) + + it('returns true for modelContentBlockStopEvent', () => { + const event: ModelStreamEvent = { type: 'modelContentBlockStopEvent' } + expect(isModelStreamEvent(event)).toBe(true) + }) + + it('returns true for modelMessageStopEvent', () => { + const event: ModelStreamEvent = { type: 'modelMessageStopEvent', stopReason: 'endTurn' } + expect(isModelStreamEvent(event)).toBe(true) + }) + + it('returns true for modelMetadataEvent', () => { + const event: ModelStreamEvent = { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + } + expect(isModelStreamEvent(event)).toBe(true) + }) + + it('returns true for modelRedactionEvent', () => { + const event: ModelStreamEvent = { + type: 'modelRedactionEvent', + inputRedaction: { replaceContent: '[User input redacted.]' }, + } + expect(isModelStreamEvent(event)).toBe(true) + }) + + it('returns false for unknown event types', () => { + const event = { type: 'unknownEvent' } + expect(isModelStreamEvent(event)).toBe(false) + }) + + it('returns false for content block types', () => { + const event = { type: 'textBlock', text: 'hello' } + expect(isModelStreamEvent(event)).toBe(false) + }) +}) diff --git a/strands-ts/src/models/__tests__/test-utils.ts b/strands-ts/src/models/__tests__/test-utils.ts new file mode 100644 index 0000000000..6ebe8bb533 --- /dev/null +++ b/strands-ts/src/models/__tests__/test-utils.ts @@ -0,0 +1,19 @@ +// ABOUTME: Shared test utilities for model tests +// ABOUTME: Contains helper functions for collecting stream events and other common test operations + +import type { ModelStreamEvent } from '../streaming.js' + +/** + * Helper function to collect all events from a stream. + * Useful for testing streaming model responses. + * + * @param stream - An async iterable of ModelStreamEvent + * @returns Promise resolving to an array of all emitted events + */ +export async function collectEvents(stream: AsyncIterable): Promise { + const events: ModelStreamEvent[] = [] + for await (const event of stream) { + events.push(event) + } + return events +} diff --git a/strands-ts/src/models/__tests__/vercel.test.ts b/strands-ts/src/models/__tests__/vercel.test.ts new file mode 100644 index 0000000000..48acbb6fe5 --- /dev/null +++ b/strands-ts/src/models/__tests__/vercel.test.ts @@ -0,0 +1,869 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import type { + LanguageModelV3, + LanguageModelV3CallOptions, + LanguageModelV3StreamPart, + LanguageModelV3StreamResult, +} from '@ai-sdk/provider' +import { APICallError } from '@ai-sdk/provider' +import { VercelModel } from '../vercel.js' +import { ContextWindowOverflowError, ModelError, ModelThrottledError } from '../../errors.js' +import { logger } from '../../logging/logger.js' +import { collectIterator } from '../../__fixtures__/model-test-helpers.js' +import { Message, TextBlock, ToolUseBlock, ToolResultBlock, ReasoningBlock, JsonBlock } from '../../types/messages.js' +import { DocumentBlock, ImageBlock, VideoBlock } from '../../types/media.js' +import type { ToolSpec } from '../../tools/types.js' + +/** + * Creates a mock LanguageModelV3 that streams the given parts. + */ +function createMockModel(parts: LanguageModelV3StreamPart[]): LanguageModelV3 { + return { + specificationVersion: 'v3', + provider: 'test', + modelId: 'test-model', + supportedUrls: {}, + doGenerate: vi.fn(), + doStream: vi.fn( + async (): Promise => ({ + stream: new ReadableStream({ + start(controller) { + for (const part of parts) { + controller.enqueue(part) + } + controller.close() + }, + }), + }) + ), + } +} + +/** Standard usage object for finish events */ +const testUsage = { + inputTokens: { total: 10, noCache: 10, cacheRead: undefined, cacheWrite: undefined }, + outputTokens: { total: 5, noCache: undefined, text: 5, reasoning: undefined }, +} + +/** Standard finish reason */ +const stopFinish = { unified: 'stop' as const, raw: 'stop' } + +/** Minimal stream parts that produce a valid (empty) response */ +const minimalParts: LanguageModelV3StreamPart[] = [ + { type: 'stream-start', warnings: [] }, + { type: 'finish', usage: testUsage, finishReason: stopFinish }, +] + +/** + * Creates a model backed by a mock that streams the given parts, + * collects events, and returns the mock's doStream call args for inspection. + */ +function setupCaptureTest( + parts: LanguageModelV3StreamPart[] = minimalParts, + config?: Parameters[0] +): { + model: VercelModel + mock: LanguageModelV3 + callArgs: () => LanguageModelV3CallOptions + collect: (messages: Message[], options?: Parameters[1]) => ReturnType +} { + const mock = createMockModel(parts) + const model = new VercelModel({ provider: mock, ...config }) + return { + model, + mock, + callArgs: () => (mock.doStream as ReturnType).mock.calls[0]![0] as LanguageModelV3CallOptions, + collect: (messages, options) => collectIterator(model.stream(messages, options)), + } +} + +describe('VercelModel', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('constructor and config', () => { + it('uses model.modelId as default and allows override', () => { + const mock = createMockModel([]) + expect(new VercelModel({ provider: mock }).getConfig().modelId).toBe('test-model') + expect(new VercelModel({ provider: mock, modelId: 'custom-id' }).getConfig().modelId).toBe('custom-id') + }) + + it('passes through all config fields', () => { + const mock = createMockModel([]) + const model = new VercelModel({ + provider: mock, + maxTokens: 100, + temperature: 0.5, + topP: 0.9, + topK: 40, + presencePenalty: 0.5, + frequencyPenalty: 0.3, + stopSequences: ['END'], + seed: 42, + }) + expect(model.getConfig()).toStrictEqual({ + modelId: 'test-model', + maxTokens: 100, + temperature: 0.5, + topP: 0.9, + topK: 40, + presencePenalty: 0.5, + frequencyPenalty: 0.3, + stopSequences: ['END'], + seed: 42, + }) + }) + + it('updateConfig merges config and getConfig returns a copy', () => { + const mock = createMockModel([]) + const model = new VercelModel({ provider: mock }) + model.updateConfig({ modelId: 'updated', maxTokens: 200 }) + const config1 = model.getConfig() + const config2 = model.getConfig() + expect(config1).toStrictEqual({ modelId: 'updated', maxTokens: 200 }) + expect(config1).not.toBe(config2) + }) + }) + + describe('stream', () => { + describe('text streaming', () => { + it('emits correct events for simple text response', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { type: 'text-start', id: 't1' }, + { type: 'text-delta', id: 't1', delta: 'Hello' }, + { type: 'text-delta', id: 't1', delta: ' world' }, + { type: 'text-end', id: 't1' }, + { type: 'finish', usage: testUsage, finishReason: stopFinish }, + ]) + + const events = await collectIterator(model.stream([])) + + expect(events[0]).toMatchObject({ type: 'modelMessageStartEvent', role: 'assistant' }) + expect(events[1]).toMatchObject({ type: 'modelContentBlockStartEvent' }) + expect(events[2]).toMatchObject({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + }) + expect(events[3]).toMatchObject({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: ' world' }, + }) + expect(events[4]).toMatchObject({ type: 'modelContentBlockStopEvent' }) + expect(events[5]).toMatchObject({ type: 'modelMetadataEvent' }) + expect(events[6]).toMatchObject({ type: 'modelMessageStopEvent', stopReason: 'endTurn' }) + }) + }) + + describe('reasoning streaming', () => { + it('emits reasoning content delta events', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { type: 'reasoning-start', id: 'r1' }, + { type: 'reasoning-delta', id: 'r1', delta: 'Let me think...' }, + { type: 'reasoning-end', id: 'r1' }, + { type: 'text-start', id: 't1' }, + { type: 'text-delta', id: 't1', delta: 'Answer' }, + { type: 'text-end', id: 't1' }, + { type: 'finish', usage: testUsage, finishReason: stopFinish }, + ]) + + const events = await collectIterator(model.stream([])) + + const reasoningDelta = events.find( + (e) => e.type === 'modelContentBlockDeltaEvent' && e.delta.type === 'reasoningContentDelta' + ) + expect(reasoningDelta).toMatchObject({ + delta: { type: 'reasoningContentDelta', text: 'Let me think...' }, + }) + }) + }) + + describe('tool call streaming', () => { + it('synthesizes start/delta/stop from complete tool-call part', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { type: 'tool-call', toolCallId: 'call_1', toolName: 'calculator', input: '{"expr":"2+2"}' }, + { type: 'finish', usage: testUsage, finishReason: { unified: 'tool-calls', raw: 'tool_calls' } }, + ]) + + const events = await collectIterator(model.stream([])) + + expect(events[1]).toMatchObject({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'calculator', toolUseId: 'call_1' }, + }) + expect(events[2]).toMatchObject({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"expr":"2+2"}' }, + }) + expect(events[3]).toMatchObject({ type: 'modelContentBlockStopEvent' }) + expect(events[5]).toMatchObject({ type: 'modelMessageStopEvent', stopReason: 'toolUse' }) + }) + + it('normalizes object tool-call input to JSON string', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { + type: 'tool-call', + toolCallId: 'call_1', + toolName: 'calculator', + input: { expr: '2+2' } as unknown as string, + }, + { type: 'finish', usage: testUsage, finishReason: { unified: 'tool-calls', raw: 'tool_calls' } }, + ]) + + const events = await collectIterator(model.stream([])) + + expect(events[2]).toMatchObject({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"expr":"2+2"}' }, + }) + }) + + it('skips duplicate tool-call when incremental tool-input events were already emitted', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { type: 'tool-input-start', id: 'call_1', toolName: 'calculator' }, + { type: 'tool-input-delta', id: 'call_1', delta: '{"expr":"2+2"}' }, + { type: 'tool-input-end', id: 'call_1' }, + { type: 'tool-call', toolCallId: 'call_1', toolName: 'calculator', input: '{"expr":"2+2"}' }, + { type: 'finish', usage: testUsage, finishReason: { unified: 'tool-calls', raw: 'tool_calls' } }, + ]) + + const events = await collectIterator(model.stream([])) + + const toolStarts = events.filter( + (e) => e.type === 'modelContentBlockStartEvent' && e.start?.type === 'toolUseStart' + ) + expect(toolStarts).toHaveLength(1) + }) + + it('emits tool use start/delta/stop events', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { type: 'tool-input-start', id: 'call_1', toolName: 'calculator' }, + { type: 'tool-input-delta', id: 'call_1', delta: '{"expr' }, + { type: 'tool-input-delta', id: 'call_1', delta: '":"2+2"}' }, + { type: 'tool-input-end', id: 'call_1' }, + { type: 'finish', usage: testUsage, finishReason: { unified: 'tool-calls', raw: 'tool_calls' } }, + ]) + + const events = await collectIterator(model.stream([])) + + expect(events[1]).toMatchObject({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'calculator', toolUseId: 'call_1' }, + }) + expect(events[2]).toMatchObject({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{"expr' }, + }) + expect(events[3]).toMatchObject({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '":"2+2"}' }, + }) + expect(events[4]).toMatchObject({ type: 'modelContentBlockStopEvent' }) + expect(events[6]).toMatchObject({ type: 'modelMessageStopEvent', stopReason: 'toolUse' }) + }) + }) + + describe('finish reasons', () => { + it.each([ + ['stop', 'endTurn'], + ['length', 'maxTokens'], + ['content-filter', 'contentFiltered'], + ['tool-calls', 'toolUse'], + ['other', 'endTurn'], + ] as const)('maps Language Model "%s" to Strands "%s"', async (unified, expected) => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { type: 'finish', usage: testUsage, finishReason: { unified, raw: unified } }, + ]) + + const events = await collectIterator(model.stream([])) + const stopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(stopEvent?.stopReason).toBe(expected) + }) + + it('throws ModelError for error finish reason', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { type: 'finish', usage: testUsage, finishReason: { unified: 'error', raw: 'internal_error' } }, + ]) + + await expect(collectIterator(model.stream([]))).rejects.toThrow(ModelError) + }) + }) + + describe('usage mapping', () => { + it('maps usage with cache tokens', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { + type: 'finish', + usage: { + inputTokens: { total: 100, noCache: 80, cacheRead: 15, cacheWrite: 5 }, + outputTokens: { total: 50, text: 40, reasoning: 10 }, + }, + finishReason: stopFinish, + }, + ]) + + const events = await collectIterator(model.stream([])) + const metaEvent = events.find((e) => e.type === 'modelMetadataEvent') + + expect(metaEvent?.usage).toEqual({ + inputTokens: 100, + outputTokens: 50, + totalTokens: 150, + cacheReadInputTokens: 15, + cacheWriteInputTokens: 5, + }) + }) + + it('handles undefined token counts', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { + type: 'finish', + usage: { + inputTokens: { total: undefined, noCache: undefined, cacheRead: undefined, cacheWrite: undefined }, + outputTokens: { total: undefined, text: undefined, reasoning: undefined }, + }, + finishReason: stopFinish, + }, + ]) + + const events = await collectIterator(model.stream([])) + const metaEvent = events.find((e) => e.type === 'modelMetadataEvent') + + expect(metaEvent?.usage).toEqual({ + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + }) + }) + }) + + describe('error handling', () => { + it('throws ModelError on stream error part', async () => { + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { type: 'error', error: new Error('rate limit exceeded') }, + ]) + + await expect(collectIterator(model.stream([]))).rejects.toThrow(ModelError) + }) + + it('throws ModelError when doStream fails with generic error', async () => { + const { mock, model } = setupCaptureTest() + ;(mock.doStream as ReturnType).mockRejectedValue(new Error('connection failed')) + + await expect(collectIterator(model.stream([]))).rejects.toThrow( + 'Language model stream error: connection failed' + ) + }) + + it('throws ModelThrottledError for APICallError with status 429', async () => { + const { mock, model } = setupCaptureTest() + ;(mock.doStream as ReturnType).mockRejectedValue( + new APICallError({ + message: 'Too many requests', + url: 'https://api.example.com', + requestBodyValues: {}, + statusCode: 429, + }) + ) + + await expect(collectIterator(model.stream([]))).rejects.toThrow(ModelThrottledError) + }) + + it('throws ContextWindowOverflowError for APICallError with context overflow in responseBody', async () => { + const { mock, model } = setupCaptureTest() + ;(mock.doStream as ReturnType).mockRejectedValue( + new APICallError({ + message: 'Bad request', + url: 'https://api.example.com', + requestBodyValues: {}, + statusCode: 400, + responseBody: 'Input is too long for requested model', + }) + ) + + await expect(collectIterator(model.stream([]))).rejects.toThrow(ContextWindowOverflowError) + }) + + it('throws ContextWindowOverflowError for non-APICallError with context overflow message', async () => { + const { mock, model } = setupCaptureTest() + ;(mock.doStream as ReturnType).mockRejectedValue( + new Error('context_length_exceeded: maximum context length is 128000') + ) + + await expect(collectIterator(model.stream([]))).rejects.toThrow(ContextWindowOverflowError) + }) + + it('classifies errors thrown during reader.read()', async () => { + const mock = createMockModel([]) + ;(mock.doStream as ReturnType).mockResolvedValue({ + stream: new ReadableStream({ + start(controller) { + controller.enqueue({ type: 'stream-start', warnings: [] }) + controller.error( + new APICallError({ + message: 'Too many requests', + url: 'https://api.example.com', + requestBodyValues: {}, + statusCode: 429, + }) + ) + }, + }), + }) + const model = new VercelModel({ provider: mock }) + + await expect(collectIterator(model.stream([]))).rejects.toThrow(ModelThrottledError) + }) + }) + + describe('call options forwarding', () => { + it('forwards config to doStream', async () => { + const { collect, callArgs } = setupCaptureTest(minimalParts, { + maxTokens: 100, + temperature: 0.7, + topP: 0.95, + topK: 40, + presencePenalty: 0.5, + frequencyPenalty: 0.3, + stopSequences: ['END'], + seed: 42, + }) + await collect([]) + + expect(callArgs()).toMatchObject({ + maxOutputTokens: 100, + temperature: 0.7, + topP: 0.95, + topK: 40, + presencePenalty: 0.5, + frequencyPenalty: 0.3, + stopSequences: ['END'], + seed: 42, + }) + }) + + it('omits undefined config values', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([]) + + const args = callArgs() + for (const key of [ + 'maxOutputTokens', + 'temperature', + 'topP', + 'topK', + 'presencePenalty', + 'frequencyPenalty', + 'stopSequences', + 'seed', + ]) { + expect(args).not.toHaveProperty(key) + } + }) + }) + + it('logs response-metadata at debug level', async () => { + const debugSpy = vi.spyOn(logger, 'debug').mockImplementation(() => {}) + const { model } = setupCaptureTest([ + { type: 'stream-start', warnings: [] }, + { type: 'text-start', id: 't1' }, + { type: 'text-delta', id: 't1', delta: 'Hi' }, + { type: 'text-end', id: 't1' }, + { type: 'response-metadata', id: 'resp1', timestamp: new Date() } as any, + { type: 'finish', usage: testUsage, finishReason: stopFinish }, + ]) + + const events = await collectIterator(model.stream([])) + expect(events.map((e) => e.type)).not.toContain('response-metadata') + expect(debugSpy).toHaveBeenCalled() + debugSpy.mockRestore() + }) + }) + + describe('message formatting', () => { + describe('system prompt', () => { + it('formats string system prompt', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([], { systemPrompt: 'You are helpful.' }) + + expect(callArgs().prompt[0]).toEqual({ role: 'system', content: 'You are helpful.' }) + }) + + it('formats system prompt content blocks', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([], { systemPrompt: [{ text: 'Part 1' }, { text: 'Part 2' }] as any }) + + expect(callArgs().prompt[0]).toEqual({ role: 'system', content: 'Part 1Part 2' }) + }) + + it('ignores cache points in system prompt', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const { collect, callArgs } = setupCaptureTest() + await collect([], { + systemPrompt: [ + { type: 'textBlock', text: 'Hello' }, + { type: 'cachePointBlock', cacheType: 'default' }, + ] as any, + }) + + expect(callArgs().prompt[0]).toEqual({ role: 'system', content: 'Hello' }) + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + + it('ignores guard content in system prompt', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const { collect, callArgs } = setupCaptureTest() + await collect([], { + systemPrompt: [ + { type: 'textBlock', text: 'Hello' }, + { type: 'guardContentBlock', guardContent: {} }, + ] as any, + }) + + expect(callArgs().prompt[0]).toEqual({ role: 'system', content: 'Hello' }) + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + }) + + describe('user messages', () => { + it('formats user text message', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([new Message({ role: 'user', content: [new TextBlock('Hello')] })]) + + const userMsg = callArgs().prompt[0] as any + expect(userMsg.role).toBe('user') + expect(userMsg.content[0]).toEqual({ type: 'text', text: 'Hello' }) + }) + + it('formats image blocks with bytes and URL sources', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([ + new Message({ + role: 'user', + content: [ + new ImageBlock({ format: 'png', source: { bytes: new Uint8Array([1, 2, 3]) } }), + new ImageBlock({ format: 'png', source: { url: 'https://example.com/image.png' } }), + ], + }), + ]) + + const userMsg = callArgs().prompt[0] as any + expect(userMsg.content[0]).toMatchObject({ type: 'file', mediaType: 'image/png' }) + expect(userMsg.content[0].data).toBeInstanceOf(Uint8Array) + expect(userMsg.content[1]).toMatchObject({ type: 'file', mediaType: 'image/png' }) + expect(userMsg.content[1].data).toBeInstanceOf(URL) + expect(userMsg.content[1].data.href).toBe('https://example.com/image.png') + }) + + it('formats document content block source as text parts', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([ + new Message({ + role: 'user', + content: [ + new DocumentBlock({ + format: 'txt', + name: 'doc', + source: { content: [{ text: 'paragraph 1' }, { text: 'paragraph 2' }] }, + }), + ], + }), + ]) + + const userMsg = callArgs().prompt[0] as any + expect(userMsg.content).toHaveLength(2) + expect(userMsg.content[0]).toEqual({ type: 'text', text: 'paragraph 1' }) + expect(userMsg.content[1]).toEqual({ type: 'text', text: 'paragraph 2' }) + }) + + it('formats video bytes in user messages', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([ + new Message({ + role: 'user', + content: [new VideoBlock({ format: 'mp4', source: { bytes: new Uint8Array([1, 2]) } })], + }), + ]) + + const userMsg = callArgs().prompt[0] as any + expect(userMsg.content[0]).toMatchObject({ type: 'file', mediaType: 'video/mp4' }) + }) + + it.each([ + { + name: 'image S3 source', + block: new ImageBlock({ + format: 'png', + source: { location: { type: 's3', uri: 's3://bucket/key', bucketOwner: '' } }, + }), + }, + { + name: 'video S3 source', + block: new VideoBlock({ + format: 'mp4', + source: { location: { type: 's3', uri: 's3://bucket/video', bucketOwner: '' } }, + }), + }, + ])('skips unsupported $name', async ({ block }) => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const { collect, callArgs } = setupCaptureTest() + await collect([new Message({ role: 'user', content: [block] })]) + + expect(callArgs().prompt).toHaveLength(0) + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + }) + + describe('assistant messages', () => { + it('formats text and tool use blocks', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([ + new Message({ + role: 'assistant', + content: [ + new TextBlock('Let me calculate'), + new ToolUseBlock({ name: 'calc', toolUseId: 'tu1', input: { x: 1 } }), + ], + }), + ]) + + const prompt = callArgs().prompt + expect(prompt).toHaveLength(1) + const assistantMsg = prompt[0] as any + expect(assistantMsg.role).toBe('assistant') + expect(assistantMsg.content).toHaveLength(2) + expect(assistantMsg.content[0]).toEqual({ type: 'text', text: 'Let me calculate' }) + expect(assistantMsg.content[1].type).toBe('tool-call') + expect(assistantMsg.content[1].toolCallId).toBe('tu1') + }) + + it('formats reasoning blocks', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([ + new Message({ + role: 'assistant', + content: [new ReasoningBlock({ text: 'thinking...' })], + }), + ]) + + const assistantMsg = callArgs().prompt[0] as any + expect(assistantMsg.content[0]).toEqual({ type: 'reasoning', text: 'thinking...' }) + }) + + it('warns and skips tool results in assistant messages', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const { collect, callArgs } = setupCaptureTest() + await collect([ + new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ name: 'calc', toolUseId: 'tu1', input: {} }), + new ToolResultBlock({ toolUseId: 'tu1', status: 'success', content: [new TextBlock('42')] }), + ], + }), + ]) + + const prompt = callArgs().prompt + expect(prompt).toHaveLength(1) + const assistantMsg = prompt[0] as any + expect(assistantMsg.content).toHaveLength(1) + expect(assistantMsg.content[0].type).toBe('tool-call') + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + + it('handles assistant message with no tool results', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([new Message({ role: 'assistant', content: [new TextBlock('Just text')] })]) + + const prompt = callArgs().prompt + expect(prompt).toHaveLength(1) + expect((prompt[0] as any).role).toBe('assistant') + }) + }) + describe('tool result output formatting', () => { + function toolResultMessages( + content: ToolResultBlock['content'], + status: 'success' | 'error' = 'success' + ): Message[] { + return [ + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'tool', toolUseId: 'tu1', input: {} })], + }), + new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId: 'tu1', status, content })], + }), + ] + } + + async function getToolOutput(content: ToolResultBlock['content'], status?: 'success' | 'error'): Promise { + const { collect, callArgs } = setupCaptureTest() + await collect(toolResultMessages(content, status)) + return (callArgs().prompt.find((m: any) => m.role === 'tool') as any).content[0].output + } + + it('formats error status with text and fallback', async () => { + expect(await getToolOutput([new TextBlock('boom')], 'error')).toStrictEqual({ + type: 'error-text', + value: 'boom', + }) + expect(await getToolOutput([], 'error')).toStrictEqual({ + type: 'error-text', + value: 'Tool execution failed', + }) + }) + + it.each([ + { name: 'text', content: [new TextBlock('result')], expected: [{ type: 'text', text: 'result' }] }, + { + name: 'json', + content: [new JsonBlock({ json: { k: 'v' } })], + expected: [{ type: 'text', text: '{"k":"v"}' }], + }, + { + name: 'image URL', + content: [new ImageBlock({ format: 'png', source: { url: 'https://example.com/img.png' } })], + expected: [{ type: 'text', text: 'https://example.com/img.png' }], + }, + { + name: 'document text', + content: [new DocumentBlock({ format: 'txt', name: 'd', source: { text: 'doc' } })], + expected: [{ type: 'text', text: 'doc' }], + }, + { + name: 'document content blocks', + content: [ + new DocumentBlock({ format: 'txt', name: 'd', source: { content: [{ text: 'p1' }, { text: 'p2' }] } }), + ], + expected: [ + { type: 'text', text: 'p1' }, + { type: 'text', text: 'p2' }, + ], + }, + ])('formats $name content as text', async ({ content, expected }) => { + expect(await getToolOutput(content)).toStrictEqual({ type: 'content', value: expected }) + }) + + it.each([ + { + name: 'image bytes', + content: new ImageBlock({ format: 'png', source: { bytes: new Uint8Array([1]) } }), + mediaType: 'image/png', + }, + { + name: 'document bytes', + content: new DocumentBlock({ format: 'pdf', name: 'd', source: { bytes: new Uint8Array([1]) } }), + mediaType: 'application/pdf', + }, + { + name: 'video bytes', + content: new VideoBlock({ format: 'mp4', source: { bytes: new Uint8Array([1]) } }), + mediaType: 'video/mp4', + }, + ])('formats $name as file-data', async ({ content, mediaType }) => { + const output = await getToolOutput([content]) + expect(output.value[0]).toMatchObject({ type: 'file-data', mediaType }) + }) + + it.each([ + { + name: 'image S3', + block: new ImageBlock({ + format: 'png', + source: { location: { type: 's3', uri: 's3://b/k', bucketOwner: '' } }, + }), + }, + { + name: 'document S3', + block: new DocumentBlock({ + format: 'pdf', + name: 'd', + source: { location: { type: 's3', uri: 's3://b/k', bucketOwner: '' } }, + } as any), + }, + { + name: 'video S3', + block: new VideoBlock({ + format: 'mp4', + source: { location: { type: 's3', uri: 's3://b/k', bucketOwner: '' } }, + }), + }, + ])('warns on unsupported $name source', async ({ block }) => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + await getToolOutput([block]) + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + }) + }) + + describe('tool formatting', () => { + it('formats tool specs', async () => { + const tools: ToolSpec[] = [ + { + name: 'calculator', + description: 'Does math', + inputSchema: { type: 'object', properties: { expr: { type: 'string' } }, required: ['expr'] }, + }, + ] + + const { collect, callArgs } = setupCaptureTest() + await collect([], { toolSpecs: tools }) + + expect(callArgs().tools![0]).toMatchObject({ + type: 'function', + name: 'calculator', + description: 'Does math', + }) + }) + + it('handles tool spec with no inputSchema', async () => { + const tools: ToolSpec[] = [{ name: 'noop', description: 'Does nothing' }] + + const { collect, callArgs } = setupCaptureTest() + await collect([], { toolSpecs: tools }) + + const tool = callArgs().tools![0]! + expect(tool.type).toBe('function') + if (tool.type === 'function') { + expect(tool.inputSchema).toEqual({ type: 'object', properties: {} }) + } + }) + + it.each([ + { name: 'auto', input: { auto: {} }, expected: { type: 'auto' } }, + { name: 'any -> required', input: { any: {} }, expected: { type: 'required' } }, + { name: 'specific tool', input: { tool: { name: 'calc' } }, expected: { type: 'tool', toolName: 'calc' } }, + ])('maps toolChoice $name', async ({ input, expected }) => { + const { collect, callArgs } = setupCaptureTest() + await collect([], { toolChoice: input }) + + expect(callArgs().toolChoice).toEqual(expected) + }) + + it('omits tools when not provided', async () => { + const { collect, callArgs } = setupCaptureTest() + await collect([]) + + const args = callArgs() + expect(args).not.toHaveProperty('tools') + expect(args).not.toHaveProperty('toolChoice') + }) + }) +}) diff --git a/strands-ts/src/models/anthropic.ts b/strands-ts/src/models/anthropic.ts new file mode 100644 index 0000000000..d44356ce2b --- /dev/null +++ b/strands-ts/src/models/anthropic.ts @@ -0,0 +1,577 @@ +import Anthropic, { type ClientOptions } from '@anthropic-ai/sdk' +import { + Model, + type BaseModelConfig, + type CountTokensOptions, + type StreamOptions, + resolveConfigMetadata, +} from '../models/model.js' +import type { Message, ContentBlock } from '../types/messages.js' +import type { ModelStreamEvent } from '../models/streaming.js' +import { createEmptyUsage } from '../models/streaming.js' +import { ContextWindowOverflowError, ModelThrottledError, normalizeError } from '../errors.js' +import type { ImageBlock, DocumentBlock } from '../types/media.js' +import { encodeBase64 } from '../types/media.js' +import { logger } from '../logging/logger.js' +import { warnOnce } from '../logging/warn-once.js' +import { MODEL_DEFAULTS, defaultMaxTokensWarningMessage, defaultModelWarningMessage } from './defaults.js' + +const CONTEXT_WINDOW_OVERFLOW_ERRORS = [ + 'prompt is too long', + 'max_tokens exceeded', + 'input too long', + 'input is too long', + 'input length exceeds context window', + 'input and output tokens exceed your context limit', +] +const TEXT_FILE_FORMATS = ['txt', 'md', 'markdown', 'csv', 'json', 'xml', 'html', 'yml', 'yaml', 'js', 'ts', 'py'] + +export interface AnthropicModelConfig extends BaseModelConfig { + /** + * Maximum number of tokens the model can generate in a response. + * + * @defaultValue 64000 — subject to change between versions. + * Set this explicitly to avoid unexpected changes. + */ + maxTokens?: number + stopSequences?: string[] + params?: Record + + /** + * Beta features to enable via the `anthropic-beta` header. + * + * No header is sent by default. Provide a list of beta identifiers to opt into + * features such as `interleaved-thinking-2025-05-14` or `mcp-client-2025-11-20`. + * + * @see https://docs.anthropic.com/en/api/beta-headers + */ + betas?: string[] + + /** + * Whether to use the native Anthropic countTokens API. + * + * When `true`, `countTokens()` calls the Anthropic token counting API for + * accurate counts. When `false` or not set (default), skips the API call and uses + * the character-based heuristic estimator. + * + * @defaultValue false + */ + useNativeTokenCount?: boolean +} + +export interface AnthropicModelOptions extends AnthropicModelConfig { + apiKey?: string + client?: Anthropic + clientConfig?: ClientOptions +} + +export class AnthropicModel extends Model { + private _config: AnthropicModelConfig + private _client: Anthropic + + constructor(options?: AnthropicModelOptions) { + super() + const { apiKey, client, clientConfig, ...modelConfig } = options || {} + + this._config = { + modelId: MODEL_DEFAULTS.anthropic.modelId, + maxTokens: MODEL_DEFAULTS.anthropic.maxTokens, + ...modelConfig, + } + + if (modelConfig.modelId === undefined) { + warnOnce(logger, defaultModelWarningMessage(MODEL_DEFAULTS.anthropic.modelId)) + } + + if (modelConfig.maxTokens === undefined) { + warnOnce(logger, defaultMaxTokensWarningMessage(MODEL_DEFAULTS.anthropic.maxTokens)) + } + + if (client) { + this._client = client + } else { + const hasEnvKey = + typeof process !== 'undefined' && typeof process.env !== 'undefined' && process.env.ANTHROPIC_API_KEY + + if (!apiKey && !hasEnvKey) { + throw new Error( + "Anthropic API key is required. Provide it via the 'apiKey' option or set the ANTHROPIC_API_KEY environment variable." + ) + } + + this._client = new Anthropic({ + ...(apiKey ? { apiKey } : {}), + ...clientConfig, + }) + } + } + + updateConfig(modelConfig: AnthropicModelConfig): void { + this._config = { ...this._config, ...modelConfig } + } + + getConfig(): AnthropicModelConfig { + return resolveConfigMetadata(this._config, this._config.modelId ?? MODEL_DEFAULTS.anthropic.modelId) + } + + /** + * Count tokens using Anthropic's native countTokens API. + * + * Uses the same message format as the Messages API to get accurate token counts + * directly from the Anthropic service. Falls back to the base class heuristic on failure. + * + * @param messages - Array of conversation messages to count tokens for + * @param options - Optional options containing system prompt and tool specs + * @returns Total input token count + */ + override async countTokens(messages: Message[], options?: CountTokensOptions): Promise { + if (this._config.useNativeTokenCount !== true) return super.countTokens(messages, options) + + try { + const request = this._formatRequest(messages, options) + const params: Anthropic.MessageCountTokensParams = { + model: request.model, + messages: request.messages, + ...(request.system && { system: request.system }), + ...(request.tools && { tools: request.tools }), + ...(request.tool_choice && { tool_choice: request.tool_choice }), + } + + const requestOptions = this._buildRequestOptions() + const response = requestOptions + ? await this._client.messages.countTokens(params, requestOptions) + : await this._client.messages.countTokens(params) + + logger.debug(`total_tokens=<${response.input_tokens}> | native token count`) + return response.input_tokens + } catch (error) { + logger.debug(`error=<${error}> | native token counting failed, falling back to estimation`) + return super.countTokens(messages, options) + } + } + + async *stream(messages: Message[], options?: StreamOptions): AsyncIterable { + try { + const request = this._formatRequest(messages, options) + const requestOptions = this._buildRequestOptions() + const stream = requestOptions + ? this._client.messages.stream(request, requestOptions) + : this._client.messages.stream(request) + + const usage = createEmptyUsage() + + let stopReason = 'endTurn' + + for await (const event of stream) { + switch (event.type) { + case 'message_start': { + usage.inputTokens = event.message.usage.input_tokens + + const rawUsage = event.message.usage as unknown as Record + if (rawUsage.cache_creation_input_tokens !== undefined) { + usage.cacheWriteInputTokens = rawUsage.cache_creation_input_tokens + } + if (rawUsage.cache_read_input_tokens !== undefined) { + usage.cacheReadInputTokens = rawUsage.cache_read_input_tokens + } + + yield { + type: 'modelMessageStartEvent', + role: event.message.role, + } + break + } + + case 'content_block_start': + if (event.content_block.type === 'tool_use') { + yield { + type: 'modelContentBlockStartEvent', + start: { + type: 'toolUseStart', + name: event.content_block.name, + toolUseId: event.content_block.id, + }, + } + } else if (event.content_block.type === 'thinking') { + yield { type: 'modelContentBlockStartEvent' } + if (event.content_block.thinking) { + yield { + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'reasoningContentDelta', + text: event.content_block.thinking, + signature: event.content_block.signature, + }, + } + } + } else if (event.content_block.type === 'redacted_thinking') { + yield { type: 'modelContentBlockStartEvent' } + yield { + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'reasoningContentDelta', + redactedContent: event.content_block.data as unknown as Uint8Array, + }, + } + } else { + yield { type: 'modelContentBlockStartEvent' } + if (event.content_block.type === 'text' && event.content_block.text) { + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: event.content_block.text }, + } + } + } + break + + case 'content_block_delta': + if (event.delta.type === 'text_delta') { + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: event.delta.text }, + } + } else if (event.delta.type === 'input_json_delta') { + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: event.delta.partial_json }, + } + } else if (event.delta.type === 'thinking_delta') { + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: event.delta.thinking }, + } + } else if (event.delta.type === 'signature_delta') { + yield { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', signature: event.delta.signature }, + } + } + break + + case 'content_block_stop': + yield { type: 'modelContentBlockStopEvent' } + break + + case 'message_delta': + if (event.usage) { + usage.outputTokens = event.usage.output_tokens + } + if (event.delta.stop_reason) { + stopReason = this._mapStopReason(event.delta.stop_reason) + } + break + + case 'message_stop': + usage.totalTokens = usage.inputTokens + usage.outputTokens + yield { + type: 'modelMetadataEvent', + usage, + } + yield { + type: 'modelMessageStopEvent', + stopReason, + } + break + } + } + } catch (unknownError) { + const error = normalizeError(unknownError) + + const lowerMessage = error.message.toLowerCase() + if (CONTEXT_WINDOW_OVERFLOW_ERRORS.some((msg) => lowerMessage.includes(msg))) { + throw new ContextWindowOverflowError(error.message) + } + + const err = unknownError as Error & { status?: number } + if (err.status === 429) { + const message = error.message ?? 'Request was throttled by the model provider' + logger.debug(`throttled | error_message=<${message}>`) + throw new ModelThrottledError(message, { cause: err }) + } + + throw error + } + } + + private _buildRequestOptions(): Anthropic.RequestOptions | undefined { + const betas = this._config.betas + if (!betas || betas.length === 0) return undefined + return { headers: { 'anthropic-beta': betas.join(',') } } + } + + private _formatRequest(messages: Message[], options?: StreamOptions): Anthropic.MessageStreamParams { + if (!this._config.modelId) throw new Error('Model ID is required') + + const request: Anthropic.MessageStreamParams = { + model: this._config.modelId, + max_tokens: this._config.maxTokens ?? MODEL_DEFAULTS.anthropic.maxTokens, + messages: this._formatMessages(messages), + stream: true, + } + + if (options?.systemPrompt) { + if (typeof options.systemPrompt === 'string') { + request.system = options.systemPrompt + } else if (Array.isArray(options.systemPrompt)) { + const systemBlocks: Anthropic.TextBlockParam[] = [] + for (let i = 0; i < options.systemPrompt.length; i++) { + const block = options.systemPrompt[i] + if (!block) continue + + if (block.type === 'textBlock') { + const nextBlock = options.systemPrompt[i + 1] + const cacheControl = nextBlock?.type === 'cachePointBlock' ? { type: 'ephemeral' as const } : undefined + + systemBlocks.push({ + type: 'text', + text: block.text, + ...(cacheControl && { cache_control: cacheControl }), + }) + + if (cacheControl) i++ + } else if (block.type === 'guardContentBlock') { + logger.warn( + 'block_type= | guard content not supported in anthropic system prompt | skipping' + ) + } + } + if (systemBlocks.length > 0) request.system = systemBlocks + } + } + + if (options?.toolSpecs?.length) { + request.tools = options.toolSpecs.map((tool) => ({ + name: tool.name, + description: tool.description, + input_schema: tool.inputSchema as Anthropic.Tool.InputSchema, + })) + + if (options.toolChoice) { + if ('auto' in options.toolChoice) { + request.tool_choice = { type: 'auto' } + } else if ('any' in options.toolChoice) { + request.tool_choice = { type: 'any' } + } else if ('tool' in options.toolChoice) { + request.tool_choice = { type: 'tool', name: options.toolChoice.tool.name } + } + } + } + + if (this._config.temperature !== undefined) request.temperature = this._config.temperature + if (this._config.topP !== undefined) request.top_p = this._config.topP + if (this._config.stopSequences !== undefined) request.stop_sequences = this._config.stopSequences + if (this._config.params) Object.assign(request, this._config.params) + + return request + } + + private _formatMessages(messages: Message[]): Anthropic.MessageParam[] { + return messages.map((msg) => { + const role = (msg.role as string) === 'tool' ? 'user' : msg.role + + const content: Anthropic.ContentBlockParam[] = [] + + for (let i = 0; i < msg.content.length; i++) { + const block = msg.content[i] + if (!block) continue + + const nextBlock = msg.content[i + 1] + const hasCachePoint = nextBlock?.type === 'cachePointBlock' + + const formattedBlock = this._formatContentBlock(block) + + if (formattedBlock) { + if (hasCachePoint && this._isCacheableBlock(formattedBlock)) { + formattedBlock.cache_control = { type: 'ephemeral' } + i++ + } + content.push(formattedBlock) + } + } + + return { + role: role as 'user' | 'assistant', + content, + } + }) + } + + private _isCacheableBlock( + block: Anthropic.ContentBlockParam | Anthropic.ToolResultBlockParam + ): block is ( + | Anthropic.TextBlockParam + | Anthropic.ImageBlockParam + | Anthropic.ToolUseBlockParam + | Anthropic.ToolResultBlockParam + | Anthropic.DocumentBlockParam + ) & { cache_control?: { type: 'ephemeral' } } { + return ['text', 'image', 'tool_use', 'tool_result', 'document'].includes(block.type) + } + + private _formatContentBlock( + block: ContentBlock + ): Anthropic.ContentBlockParam | Anthropic.ToolResultBlockParam | undefined { + switch (block.type) { + case 'textBlock': + return { type: 'text', text: block.text } + + case 'imageBlock': { + const imgBlock = block as ImageBlock + let mediaType: 'image/jpeg' | 'image/png' | 'image/gif' | 'image/webp' + + switch (imgBlock.format) { + case 'jpeg': + case 'jpg': + mediaType = 'image/jpeg' + break + case 'png': + mediaType = 'image/png' + break + case 'gif': + mediaType = 'image/gif' + break + case 'webp': + mediaType = 'image/webp' + break + default: + throw new Error(`Unsupported image format for Anthropic: ${imgBlock.format}`) + } + + if (imgBlock.source.type === 'imageSourceBytes') { + return { + type: 'image', + source: { + type: 'base64', + media_type: mediaType, + data: encodeBase64(imgBlock.source.bytes), + }, + } + } + logger.warn('source_type= | anthropic requires image bytes | url sources not fully supported') + return undefined + } + + case 'documentBlock': { + const docBlock = block as DocumentBlock + + if (docBlock.format === 'pdf' && docBlock.source.type === 'documentSourceBytes') { + return { + type: 'document', + source: { + type: 'base64', + media_type: 'application/pdf', + data: encodeBase64(docBlock.source.bytes), + }, + ...(docBlock.name && { title: docBlock.name }), + } as unknown as Anthropic.ContentBlockParam + } + + if (TEXT_FILE_FORMATS.includes(docBlock.format)) { + let textContent: string | undefined + + if (docBlock.source.type === 'documentSourceText') { + textContent = docBlock.source.text + } else if (docBlock.source.type === 'documentSourceBytes') { + if (typeof TextDecoder !== 'undefined') { + textContent = new TextDecoder().decode(docBlock.source.bytes) + } else { + logger.warn(`format=<${docBlock.format}> | cannot decode document bytes | TextDecoder not available`) + } + } + + if (textContent) { + return { + type: 'text', + text: textContent, + } + } + } + + logger.warn(`format=<${docBlock.format}> | unsupported document format or source for anthropic`) + return undefined + } + + case 'toolUseBlock': + return { + type: 'tool_use', + id: block.toolUseId, + name: block.name, + input: block.input as Record, + } + + case 'videoBlock': + logger.warn('block_type= | video blocks not supported by anthropic, skipping') + return undefined + + case 'toolResultBlock': { + const innerContent = block.content + .map((c) => { + if (c.type === 'textBlock') return { type: 'text' as const, text: c.text } + if (c.type === 'jsonBlock') return { type: 'text' as const, text: JSON.stringify(c.json) } + + // Recursively format any other content block (image, document, video, etc.) + const formatted = this._formatContentBlock(c as unknown as ContentBlock) + return formatted + }) + .filter((c): c is NonNullable => !!c) + + let contentVal: string | Anthropic.ContentBlockParam[] + + const firstItem = innerContent[0] + if (innerContent.length === 1 && firstItem && firstItem.type === 'text') { + contentVal = firstItem.text + } else { + contentVal = innerContent + } + + return { + type: 'tool_result', + tool_use_id: block.toolUseId, + content: contentVal, + is_error: block.status === 'error', + } as Anthropic.ToolResultBlockParam + } + + case 'reasoningBlock': + if (block.text && block.signature) { + return { + type: 'thinking', + thinking: block.text, + signature: block.signature, + } as unknown as Anthropic.ContentBlockParam + } else if (block.redactedContent) { + return { + type: 'redacted_thinking', + data: block.redactedContent, + } as unknown as Anthropic.ContentBlockParam + } + return undefined + + case 'cachePointBlock': + return undefined + + default: + return undefined + } + } + + private _mapStopReason(anthropicReason: string): string { + switch (anthropicReason) { + case 'end_turn': + return 'endTurn' + case 'max_tokens': + return 'maxTokens' + case 'stop_sequence': + return 'stopSequence' + case 'tool_use': + return 'toolUse' + case 'pause_turn': + return 'pauseTurn' + case 'refusal': + return 'refusal' + default: + logger.warn(`stop_reason=<${anthropicReason}> | unknown anthropic stop reason`) + return anthropicReason + } + } +} diff --git a/strands-ts/src/models/bedrock.ts b/strands-ts/src/models/bedrock.ts new file mode 100644 index 0000000000..a606ec4ee1 --- /dev/null +++ b/strands-ts/src/models/bedrock.ts @@ -0,0 +1,1840 @@ +/** + * AWS Bedrock model provider implementation. + * + * This module provides integration with AWS Bedrock's Converse API, + * supporting streaming responses, tool use, and prompt caching. + * + * @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html + */ + +import { + BedrockRuntimeClient, + type BedrockRuntimeClientConfig, + type CachePointBlock as BedrockCachePointBlock, + type CacheTTL as BedrockSdkCacheTTL, + type ContentBlock as BedrockContentBlock, + type ContentBlockDeltaEvent as BedrockContentBlockDeltaEvent, + type ContentBlockStartEvent as BedrockContentBlockStartEvent, + ConverseCommand, + type ConverseCommandOutput, + ConverseStreamCommand, + CountTokensCommand, + type ConverseStreamCommandInput, + type ConverseStreamMetadataEvent as BedrockConverseStreamMetadataEvent, + type ConverseStreamOutput, + type InferenceConfiguration, + type Message as BedrockMessage, + type MessageStartEvent as BedrockMessageStartEvent, + type MessageStopEvent as BedrockMessageStopEvent, + type ReasoningContentBlock, + type ReasoningContentBlockDelta, + type Tool, + type ToolConfiguration, + type ToolUseBlockDelta, + type ImageSource as BedrockImageSource, + type VideoSource as BedrockVideoSource, + type DocumentSource as BedrockDocumentSource, + type SystemContentBlock, + DocumentFormat, + ImageFormat, + VideoFormat, + type BedrockRuntimeClientResolvedConfig, + type CitationLocation as BedrockCitationLocation, + type Citation as BedrockCitation, + type CitationsContentBlock as BedrockCitationsContentBlock, + type CitationsDelta as BedrockCitationsDelta, + type GuardrailTraceAssessment, +} from '@aws-sdk/client-bedrock-runtime' +import { + type BaseModelConfig, + type CacheConfig, + type CountTokensOptions, + Model, + type StreamOptions, + resolveConfigMetadata, +} from '../models/model.js' +import type { ContentBlock, Message, StopReason, ToolUseBlock } from '../types/messages.js' +import type { ImageSource, VideoSource, DocumentSource } from '../types/media.js' +import type { CitationsDelta, ModelStreamEvent, ReasoningContentDelta, Usage } from '../models/streaming.js' +import type { Citation, CitationLocation, CitationsBlockData } from '../types/citations.js' +import type { JSONValue } from '../types/json.js' +import { ContextWindowOverflowError, ModelThrottledError, ProviderTokenCountError, normalizeError } from '../errors.js' +import { ensureDefined } from '../types/validation.js' +import { logger } from '../logging/logger.js' +import { warnOnce } from '../logging/warn-once.js' +import { NOOP_TOOL_SPEC } from '../tools/noop-tool.js' +import { MODEL_DEFAULTS, defaultModelWarningMessage } from './defaults.js' + +const DEFAULT_BEDROCK_REGION_SUPPORTS_FIP = false + +/** + * Default request timeout in milliseconds. The AWS SDK defaults to 0 (disabled), which lets + * a stuck connection hang indefinitely — we pick 120s to bound that. Callers can override + * via `clientConfig.requestHandler.requestTimeout`. + */ +const DEFAULT_REQUEST_TIMEOUT_MS = 120_000 + +/** + * Models that require the status field in tool results. + * According to AWS Bedrock API documentation, the status field is only supported by Anthropic Claude models. + * @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html + */ +const MODELS_INCLUDE_STATUS = ['anthropic.claude'] + +/** + * Models that support the Anthropic-style prompt caching strategy. + * Used to auto-detect when `cacheConfig.strategy` is `'auto'`. + * @see https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + */ +const MODELS_SUPPORTING_ANTHROPIC_CACHING = ['anthropic', 'claude'] + +/** + * Error messages that indicate context window overflow. + * Used to detect when input exceeds the model's context window. + */ +const BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ + 'Input is too long for requested model', + 'input length and `max_tokens` exceed context limit', + 'too many total text bytes', + 'prompt is too long', +] + +/** + * Cache of model IDs for which CountTokens API calls should be skipped. + * Prevents repeated failing API calls that will never succeed for the lifetime of the process. + */ +const SKIP_COUNT_TOKENS_MODELS = new Set() + +/** + * Mapping of Bedrock stop reasons to SDK stop reasons. + */ +const STOP_REASON_MAP = { + end_turn: 'endTurn', + tool_use: 'toolUse', + max_tokens: 'maxTokens', + stop_sequence: 'stopSequence', + content_filtered: 'contentFiltered', + guardrail_intervened: 'guardrailIntervened', +} as const + +/** + * Default message for redacted input. + */ +const DEFAULT_REDACT_INPUT_MESSAGE = '[User input redacted.]' + +/** + * Default message for redacted output. + */ +const DEFAULT_REDACT_OUTPUT_MESSAGE = '[Assistant output redacted.]' + +/** + * TTL durations accepted by Bedrock for prompt-cache checkpoints. + * + * Bedrock currently accepts `'5m'` (default) and `'1h'`. The `(string & {})` branch keeps + * autocomplete on the known values while letting callers pass any string forward — Bedrock + * validates the value server-side and rejects unsupported values with `ValidationException`, + * so this stays correct as AWS adds new TTL values without an SDK update. + * + * Bedrock also requires checkpoint TTLs to be **non-increasing** across + * `toolConfig` → system → messages — setting a longer TTL on a later checkpoint than an + * earlier one will be rejected by the service. + * + * @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CachePointBlock.html + */ +export type BedrockCacheTTL = '5m' | '1h' | (string & {}) + +/** + * Bedrock-specific prompt-caching configuration. Narrows the TTL fields onto the common + * {@link CacheConfig} for the Bedrock provider. + */ +export interface BedrockCacheConfig extends CacheConfig { + /** TTL applied to the auto-injected cache point appended after `toolConfig.tools`. */ + toolsTTL?: BedrockCacheTTL + + /** TTL applied to the auto-injected cache point appended to the last user message. */ + messagesTTL?: BedrockCacheTTL +} + +/** + * Redaction configuration for Bedrock guardrails. + * Controls whether and how blocked content is replaced. + */ +export interface BedrockGuardrailRedactionConfig { + /** Redact input when blocked. @defaultValue true */ + input?: boolean + + /** Replacement message for redacted input. @defaultValue '[User input redacted.]' */ + inputMessage?: string + + /** Redact output when blocked. @defaultValue false */ + output?: boolean + + /** Replacement message for redacted output. @defaultValue '[Assistant output redacted.]' */ + outputMessage?: string +} + +/** + * Configuration for Bedrock guardrails. + * + * @see https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails.html + */ +export interface BedrockGuardrailConfig { + /** Guardrail identifier */ + guardrailIdentifier: string + + /** Guardrail version (e.g., "1", "DRAFT") */ + guardrailVersion: string + + /** Trace mode for evaluation. @defaultValue 'enabled' */ + trace?: 'enabled' | 'disabled' | 'enabled_full' + + /** Stream processing mode */ + streamProcessingMode?: 'sync' | 'async' + + /** Redaction behavior when content is blocked */ + redaction?: BedrockGuardrailRedactionConfig + + /** + * Only evaluate the latest user message with guardrails. + * When true, wraps the latest user message's text/image content in guardContent blocks. + * This can improve performance and reduce costs in multi-turn conversations. + * + * @remarks + * The implementation finds the last user message containing text or image content + * (not just the last message), ensuring correct behavior during tool execution cycles + * where toolResult messages may follow the user's actual input. + * + * @defaultValue false + */ + guardLatestUserMessage?: boolean +} + +/** + * Converts a snake_case string to camelCase. + * Used for mapping unknown stop reasons from Bedrock to SDK format. + * + * @param str - Snake case string + * @returns Camel case string + */ +function snakeToCamel(str: string): string { + return str.replace(/_([a-z])/g, (_, letter) => letter.toUpperCase()) +} + +/** + * Configuration interface for AWS Bedrock model provider. + * + * Extends BaseModelConfig with Bedrock-specific configuration options + * for model parameters, caching, and additional request/response fields. + * + * @example + * ```typescript + * const config: BedrockModelConfig = { + * modelId: 'global.anthropic.claude-sonnet-4-6', + * maxTokens: 1024, + * temperature: 0.7, + * cacheConfig: { strategy: 'auto' } + * } + * ``` + */ +export interface BedrockModelConfig extends BaseModelConfig { + /** + * Maximum number of tokens to generate in the response. + * + * @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html + */ + maxTokens?: number + + /** + * Controls randomness in generation. + * + * @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html + */ + temperature?: number + + /** + * Controls diversity via nucleus sampling. + * + * @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html + */ + topP?: number + + /** + * Array of sequences that will stop generation when encountered. + */ + stopSequences?: string[] + + /** + * Configuration for prompt caching. + * + * @see https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + */ + cacheConfig?: BedrockCacheConfig + + /** + * Additional fields to include in the Bedrock request. + */ + additionalRequestFields?: JSONValue + + /** + * Additional response field paths to extract from the Bedrock response. + */ + additionalResponseFieldPaths?: string[] + + /** + * Additional arguments to pass through to the Bedrock Converse API. + * @see https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/client/bedrock-runtime/command/ConverseStreamCommand/ + */ + additionalArgs?: JSONValue + + /** + * Whether or not to stream responses from the model. + * + * This will use the ConverseStream API instead of the Converse API. + * + * @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + * @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html + */ + stream?: boolean + + /** + * Flag to include status field in tool results. + * - `true`: Always include status field + * - `false`: Never include status field + * - `'auto'`: Automatically determine based on model ID (default) + */ + includeToolResultStatus?: 'auto' | boolean + + /** + * Guardrail configuration for content filtering and safety controls. + * @see https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails.html + */ + guardrailConfig?: BedrockGuardrailConfig + + /** + * Whether to use the native Bedrock CountTokens API. + * + * When `true`, `countTokens()` calls the Bedrock CountTokens API for + * accurate counts. When `false` or not set (default), skips the API call and uses + * the character-based heuristic estimator. + * + * @defaultValue false + */ + useNativeTokenCount?: boolean +} + +/** + * Options for creating a BedrockModel instance. + */ +export interface BedrockModelOptions extends BedrockModelConfig { + /** + * AWS region to use for the Bedrock service. + */ + region?: string + + /** + * Configuration for the Bedrock Runtime client. + */ + clientConfig?: BedrockRuntimeClientConfig + + /** + * Amazon Bedrock API key for bearer token authentication. + * When provided, requests use the API key instead of SigV4 signing. + * @see https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys.html + */ + apiKey?: string +} + +/** + * AWS Bedrock model provider implementation. + * + * Implements the Model interface for AWS Bedrock using the Converse Stream API. + * Supports streaming responses, tool use, prompt caching, and comprehensive error handling. + * + * @example + * ```typescript + * const provider = new BedrockModel({ + * modelConfig: { + * modelId: 'global.anthropic.claude-sonnet-4-6', + * maxTokens: 1024, + * temperature: 0.7 + * }, + * clientConfig: { + * region: 'us-west-2' + * } + * }) + * + * const messages: Message[] = [ + * { type: 'message', role: 'user', content: [{ type: 'textBlock', text: 'Hello!' }] } + * ] + * + * for await (const event of provider.stream(messages)) { + * if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + * process.stdout.write(event.delta.text) + * } + * } + * ``` + */ +export class BedrockModel extends Model { + private _config: BedrockModelConfig + private _client: BedrockRuntimeClient + + /** + * Clears the cache of model IDs for which CountTokens is skipped. + * After calling this, the next countTokens invocation will attempt the API again. + * + * @internal + */ + static clearCountTokensCache(): void { + SKIP_COUNT_TOKENS_MODELS.clear() + } + + /** + * Creates a new BedrockModel instance. + * + * @param options - Optional configuration for model and client + * + * @example + * ```typescript + * // Minimal configuration with defaults + * const provider = new BedrockModel({ + * region: 'us-west-2' + * }) + * + * // With model configuration + * const provider = new BedrockModel({ + * region: 'us-west-2', + * modelId: 'global.anthropic.claude-sonnet-4-6', + * maxTokens: 2048, + * temperature: 0.8, + * cacheConfig: { strategy: 'auto' } + * }) + * + * // With client configuration + * const provider = new BedrockModel({ + * region: 'us-east-1', + * clientConfig: { + * credentials: myCredentials + * } + * }) + * ``` + */ + constructor(options?: BedrockModelOptions) { + super() + + const { region, clientConfig, apiKey, ...modelConfig } = options ?? {} + + // Initialize model config with default model ID if not provided + this._config = { + modelId: MODEL_DEFAULTS.bedrock.modelId, + ...modelConfig, + } + + if (modelConfig.modelId === undefined) { + warnOnce(logger, defaultModelWarningMessage(MODEL_DEFAULTS.bedrock.modelId)) + } + + // Build user agent string (extend if provided, otherwise use SDK identifier) + const customUserAgent = clientConfig?.customUserAgent + ? `${clientConfig.customUserAgent} strands-agents-ts-sdk` + : 'strands-agents-ts-sdk' + + this._client = new BedrockRuntimeClient({ + ...(clientConfig ?? {}), + requestHandler: withDefaultRequestTimeout(clientConfig?.requestHandler), + // region takes precedence over clientConfig + ...(region ? { region: region } : {}), + customUserAgent, + }) + + if (apiKey) { + applyApiKey(this._client, apiKey) + } + + applyDefaultRegion(this._client.config) + } + + /** + * Returns the cache strategy for this model based on its model ID. + * Returns the appropriate cache strategy name, or null if automatic caching is not supported. + * + * @returns Cache strategy name or null + */ + private _getCacheStrategy(): 'anthropic' | null { + return MODELS_SUPPORTING_ANTHROPIC_CACHING.some((pattern) => this._config.modelId?.includes(pattern)) + ? 'anthropic' + : null + } + + /** + * Determines if caching should be enabled. + * Returns true when: + * - strategy is 'anthropic' (explicit enable) + * - strategy is 'auto' and model supports caching (auto-detect) + * + * @returns True if caching should be enabled + */ + private _shouldEnableCaching(): boolean { + const cacheConfig = this._config.cacheConfig + if (!cacheConfig) { + return false + } + + let strategy = cacheConfig.strategy + + if (strategy === 'auto') { + const detectedStrategy = this._getCacheStrategy() + if (!detectedStrategy) { + logger.warn( + `model_id=<${this._config.modelId}> | cache_config is enabled but this model does not support automatic caching` + ) + return false + } + strategy = detectedStrategy + } + + return strategy === 'anthropic' + } + + /** + * Updates the model configuration. + * Merges the provided configuration with existing settings. + * + * @param modelConfig - Configuration object with model-specific settings to update + * + * @example + * ```typescript + * // Update temperature and maxTokens + * provider.updateConfig({ + * temperature: 0.9, + * maxTokens: 2048 + * }) + * ``` + */ + updateConfig(modelConfig: BedrockModelConfig): void { + this._config = { ...this._config, ...modelConfig } + } + + /** + * Retrieves the current model configuration. + * + * @returns The current configuration object + * + * @example + * ```typescript + * const config = provider.getConfig() + * console.log(config.modelId) + * ``` + */ + getConfig(): BedrockModelConfig { + return resolveConfigMetadata(this._config, this._config.modelId ?? MODEL_DEFAULTS.bedrock.modelId) + } + + /** + * Count tokens using Bedrock's native CountTokens API. + * + * Uses the same message format as the Converse API to get accurate token counts + * directly from the Bedrock service. Falls back to the base class heuristic on failure. + * + * @param messages - Array of conversation messages to count tokens for + * @param options - Optional options containing system prompt and tool specs + * @returns Total input token count + */ + override async countTokens(messages: Message[], options?: CountTokensOptions): Promise { + if (this._config.useNativeTokenCount !== true) return super.countTokens(messages, options) + + const modelId = this._config.modelId ?? MODEL_DEFAULTS.bedrock.modelId + + if (SKIP_COUNT_TOKENS_MODELS.has(modelId)) { + return super.countTokens(messages, options) + } + + try { + const request = this._formatRequest(messages, options) + const converseInput: Record = {} + if (request.messages) converseInput.messages = request.messages + if (request.system) converseInput.system = request.system + if (request.toolConfig) converseInput.toolConfig = request.toolConfig + + const response = await this._client.send( + new CountTokensCommand({ + modelId: this._config.modelId, + input: { converse: converseInput }, + }) + ) + + if (response.inputTokens == null) { + throw new ProviderTokenCountError('Bedrock CountTokens returned undefined for inputTokens') + } + + logger.debug(`total_tokens=<${response.inputTokens}> | native token count`) + return response.inputTokens + } catch (error) { + if (error instanceof Error && error.name === 'AccessDeniedException') { + warnOnce( + logger, + `model_id=<${modelId}> | bedrock:CountTokens permission denied, falling back to heuristic estimation` + ) + SKIP_COUNT_TOKENS_MODELS.add(modelId) + } else if ( + error instanceof Error && + error.name === 'ValidationException' && + error.message.includes("doesn't support counting tokens") + ) { + logger.debug( + `model_id=<${modelId}> | model does not support CountTokens, caching for future calls, falling back to estimation` + ) + SKIP_COUNT_TOKENS_MODELS.add(modelId) + } else { + logger.debug(`error=<${error}> | native token counting failed, falling back to estimation`) + } + return super.countTokens(messages, options) + } + } + + /** + * Streams a conversation with the Bedrock model. + * Returns an async iterable that yields streaming events as they occur. + * + * @param messages - Array of conversation messages + * @param options - Optional streaming configuration + * @returns Async iterable of streaming events + * + * @throws \{ContextWindowOverflowError\} When input exceeds the model's context window + * @throws \{ModelThrottledError\} When Bedrock service throttles requests + * + * @example + * ```typescript + * const messages: Message[] = [ + * { type: 'message', role: $1, content: [{ type: 'textBlock', text: 'What is 2+2?' }] } + * ] + * + * const options: StreamOptions = { + * systemPrompt: 'You are a helpful math assistant.', + * toolSpecs: [calculatorTool] + * } + * + * for await (const event of provider.stream(messages, options)) { + * if (event.type === 'modelContentBlockDeltaEvent') { + * console.log(event.delta) + * } + * } + * ``` + */ + async *stream(messages: Message[], options?: StreamOptions): AsyncIterable { + try { + // Format the request for Bedrock + const request = this._formatRequest(messages, options) + if (this._config.stream !== false) { + // Create and send the command + const command = new ConverseStreamCommand(request) + const response = await this._client.send(command) + // Stream the response + if (response.stream) { + let lastStopReason: string | undefined + for await (const chunk of response.stream) { + // Map Bedrock events to SDK events + const result = this._mapStreamedBedrockEventToSDKEvent(chunk, lastStopReason) + lastStopReason = result.stopReason + for (const event of result.events) { + yield event + } + } + } + } else { + const command = new ConverseCommand(request) + const response = await this._client.send(command) + for (const event of this._mapBedrockEventToSDKEvent(response)) { + yield event + } + } + } catch (unknownError) { + const error = normalizeError(unknownError) + + // Check for context window overflow + if (BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES.some((msg) => error.message.includes(msg))) { + throw new ContextWindowOverflowError(error.message) + } + + // Re-throw other errors as-is + throw error + } + } + + /** + * Formats a request for the Bedrock Converse Stream API. + * + * @param messages - Conversation messages + * @param options - Stream options + * @returns Formatted Bedrock request + */ + private _formatRequest(messages: Message[], options?: StreamOptions): ConverseStreamCommandInput { + const request: ConverseStreamCommandInput = { + modelId: this._config.modelId, + messages: this._formatMessages(messages), + } + + // Add system prompt + if (options?.systemPrompt !== undefined) { + if (typeof options.systemPrompt === 'string') { + request.system = [{ text: options.systemPrompt }] + } else if (options.systemPrompt.length > 0) { + request.system = options.systemPrompt.map((block) => this._formatContentBlock(block) as SystemContentBlock) + } + } + + // Add tool configuration + // Bedrock requires toolConfig when messages contain tool use/result blocks. + // When no tools were provided but messages reference past tool usage (e.g. during + // summarization), inject a noop tool to satisfy the API requirement. + let toolSpecs = options?.toolSpecs ?? [] + if (toolSpecs.length === 0) { + const hasToolBlocks = messages.some((msg) => + msg.content.some((block) => block.type === 'toolUseBlock' || block.type === 'toolResultBlock') + ) + if (hasToolBlocks) { + toolSpecs = [NOOP_TOOL_SPEC] + } + } + + if (toolSpecs.length > 0) { + const tools: Tool[] = toolSpecs.map( + (spec) => + ({ + toolSpec: { + name: spec.name, + description: spec.description, + inputSchema: { json: spec.inputSchema }, + }, + }) as Tool + ) + + if (this._shouldEnableCaching()) { + const cachePoint: BedrockCachePointBlock = { type: 'default' } + const ttl = this._config.cacheConfig?.toolsTTL + if (ttl !== undefined) { + // Bedrock validates TTL values server-side, so accept any string here. + cachePoint.ttl = ttl as BedrockSdkCacheTTL + } + tools.push({ cachePoint }) + } + + const toolConfig: ToolConfiguration = { + tools: tools, + } + + if (options?.toolChoice) { + toolConfig.toolChoice = options.toolChoice + } + + request.toolConfig = toolConfig + } + + // Add inference configuration + const inferenceConfig: InferenceConfiguration = {} + if (this._config.maxTokens !== undefined) inferenceConfig.maxTokens = this._config.maxTokens + if (this._config.temperature !== undefined) inferenceConfig.temperature = this._config.temperature + if (this._config.topP !== undefined) inferenceConfig.topP = this._config.topP + if (this._config.stopSequences !== undefined) inferenceConfig.stopSequences = this._config.stopSequences + + if (Object.keys(inferenceConfig).length > 0) { + request.inferenceConfig = inferenceConfig + } + + // Add additional request fields + const additionalRequestFields = this._getAdditionalRequestFields(options) + if (additionalRequestFields) { + request.additionalModelRequestFields = additionalRequestFields + } + + // Add additional response field paths + if (this._config.additionalResponseFieldPaths) { + request.additionalModelResponseFieldPaths = this._config.additionalResponseFieldPaths + } + + // Add additional args (spread them into the request for forward compatibility) + if (this._config.additionalArgs) { + Object.assign(request, this._config.additionalArgs) + } + + // Add guardrail configuration + if (this._config.guardrailConfig) { + request.guardrailConfig = { + guardrailIdentifier: this._config.guardrailConfig.guardrailIdentifier, + guardrailVersion: this._config.guardrailConfig.guardrailVersion, + trace: this._config.guardrailConfig.trace ?? 'enabled', + ...(this._config.guardrailConfig.streamProcessingMode && { + streamProcessingMode: this._config.guardrailConfig.streamProcessingMode, + }), + } + } + + return request + } + + /** + * Get additional request fields, adjusted for compatibility with the current stream options. + * + * Certain additional request fields are incompatible with specific API options. For example, + * Bedrock does not allow thinking mode when tool_choice forces tool use. + * + * @param options - The stream options for the current request + * @returns The additional request fields, or undefined if none + */ + private _getAdditionalRequestFields(options?: StreamOptions): JSONValue | undefined { + const fields = this._config.additionalRequestFields as Record | undefined + if (!fields || !('thinking' in fields)) { + return fields + } + + const toolChoice = options?.toolChoice + if (!toolChoice || 'auto' in toolChoice) { + return fields + } + + const { thinking: _, ...rest } = fields + return Object.keys(rest).length > 0 ? rest : undefined + } + + /** + * Formats messages for Bedrock API. + * + * @param messages - SDK messages + * @returns Bedrock-formatted messages + */ + private _formatMessages(messages: Message[]): BedrockMessage[] { + // Pre-compute the index of the last user message containing text/image content + // This ensures guardContent wrapping is maintained across tool execution cycles + const lastUserTextIdx = this._config.guardrailConfig?.guardLatestUserMessage + ? this._findLastUserTextMessageIndex(messages) + : undefined + + const formattedMessages = messages.reduce((acc, message, idx) => { + const shouldApplyGuardBlocks = idx === lastUserTextIdx + const content = message.content + .map((block: ContentBlock) => { + const formattedBlock = this._formatContentBlock(block) + return shouldApplyGuardBlocks ? this._applyGuardBlocks(formattedBlock) : formattedBlock + }) + .filter((block) => block !== undefined) + + if (content.length > 0) { + acc.push({ role: message.role, content }) + } + + return acc + }, []) + + // Inject cache point if caching is enabled + if (this._shouldEnableCaching()) { + this._injectCachePoint(formattedMessages) + } + + return formattedMessages + } + + /** + * Inject a cache point at the end of the last user message. + * Strips any existing cache points from all messages first. + * + * @param messages - List of messages to inject cache point into (modified in place) + */ + private _injectCachePoint(messages: BedrockMessage[]): void { + if (messages.length === 0) { + return + } + + let lastUserIdx: number | null = null + + // Strip existing cache points and find last user message + for (let msgIdx = 0; msgIdx < messages.length; msgIdx++) { + const msg = messages[msgIdx] + if (!msg) continue + + const content = msg.content ?? [] + + for (let blockIdx = content.length - 1; blockIdx >= 0; blockIdx--) { + const block = content[blockIdx] + if (block && 'cachePoint' in block) { + content.splice(blockIdx, 1) + logger.warn( + `msg_idx=<${msgIdx}>, block_idx=<${blockIdx}> | stripped existing cache point (auto mode manages cache points)` + ) + } + } + + if (msg.role === 'user') { + lastUserIdx = msgIdx + } + } + + // Add cache point to last user message + if (lastUserIdx !== null) { + const lastMsg = messages[lastUserIdx] + if (lastMsg && lastMsg.content) { + const cachePoint: BedrockCachePointBlock = { type: 'default' } + const ttl = this._config.cacheConfig?.messagesTTL + if (ttl !== undefined) { + // Bedrock validates TTL values server-side, so accept any string here. + cachePoint.ttl = ttl as BedrockSdkCacheTTL + } + lastMsg.content.push({ cachePoint }) + logger.debug(`msg_idx=<${lastUserIdx}> | added cache point to last user message`) + } + } + } + + /** + * Wraps a formatted content block in guardContent for guardrail evaluation. + * + * When guardLatestUserMessage is enabled, this method wraps text and image blocks + * in guardContent blocks to signal to Bedrock's guardrails to evaluate only that content. + * Other content types (toolUse, toolResult, etc.) pass through unchanged. + * + * @param formattedBlock - The formatted content block to potentially wrap + * @returns The block wrapped in guardContent if applicable, or the original block + */ + private _applyGuardBlocks(formattedBlock: BedrockContentBlock | undefined): BedrockContentBlock | undefined { + if (formattedBlock === undefined) { + return undefined + } + + if ('text' in formattedBlock) { + return { + guardContent: { + text: { + text: formattedBlock.text, + }, + }, + } + } + + if ('image' in formattedBlock) { + // Extract image data and validate for guardContent compatibility + const imageBlock = formattedBlock.image + if (!imageBlock?.format || !imageBlock?.source) { + return formattedBlock + } + + const format = imageBlock.format + + // Bedrock guardrails only support png/jpeg formats + if (format !== 'png' && format !== 'jpeg') { + logger.warn( + `image_format=<${format}> | format not supported by bedrock guardrails | skipping guardContent wrap` + ) + return formattedBlock + } + + // Bedrock guardrails only support bytes source (not S3 or URL) + if (!('bytes' in imageBlock.source)) { + logger.warn( + 'source_type= | image source must be bytes for bedrock guardrails | skipping guardContent wrap' + ) + return formattedBlock + } + + return { + guardContent: { + image: { + format: format as 'png' | 'jpeg', + source: imageBlock.source as { bytes: Uint8Array }, + }, + }, + } + } + + // Other content types (toolUse, toolResult, etc.) pass through unchanged + return formattedBlock + } + + /** + * Find the index of the last user message containing text or image content. + * + * This is used for guardLatestUserMessage guardrail evaluation to ensure that guardContent + * wrapping targets the correct message even when toolResult messages (role='user') follow + * the actual user text/image input during tool execution cycles. + * + * @param messages - Array of messages to search + * @returns Index of the last user message with text/image content, or undefined if not found + */ + private _findLastUserTextMessageIndex(messages: Message[]): number | undefined { + for (let idx = messages.length - 1; idx >= 0; idx--) { + const msg = messages[idx] + if (msg === undefined) continue + if ( + msg.role === 'user' && + msg.content.some((block) => block.type === 'textBlock' || block.type === 'imageBlock') + ) { + return idx + } + } + return undefined + } + + /** + * Determines whether to include the status field in tool results. + * + * Uses the includeToolResultStatus config option: + * - If explicitly true, always include status + * - If explicitly false, never include status + * - If 'auto' (default), check if model ID matches known patterns + * + * @returns True if status field should be included, false otherwise + */ + private _shouldIncludeToolResultStatus(): boolean { + const includeStatus = this._config.includeToolResultStatus ?? 'auto' + + if (includeStatus === true) return true + if (includeStatus === false) return false + + // Auto-detection mode: check if modelId contains any pattern + const shouldInclude = MODELS_INCLUDE_STATUS.some((pattern) => this._config.modelId?.includes(pattern)) + + // Log debug message for auto-detection + logger.debug( + `model_id=<${this._config.modelId}>, include_tool_result_status=<${shouldInclude}> | auto-detected includeToolResultStatus` + ) + + return shouldInclude + } + + /** + * Formats a content block for Bedrock API. + * + * @param block - SDK content block + * @returns Bedrock-formatted content block + */ + private _formatContentBlock(block: ContentBlock): BedrockContentBlock | undefined { + switch (block.type) { + case 'textBlock': + return { text: block.text } + + case 'toolUseBlock': + return { + toolUse: { + toolUseId: block.toolUseId, + name: block.name, + input: block.input, + }, + } + + case 'toolResultBlock': { + const content = block.content.map((content) => { + switch (content.type) { + case 'textBlock': + return { text: content.text } + case 'jsonBlock': + return { json: content.json } + case 'imageBlock': + return { + image: { + format: content.format as ImageFormat, + source: this._formatMediaSource(content.source), + }, + } + case 'videoBlock': + return { + video: { + format: content.format === '3gp' ? 'three_gp' : (content.format as VideoFormat), + source: this._formatMediaSource(content.source), + }, + } + case 'documentBlock': + return { + document: { + name: content.name, + format: content.format as DocumentFormat, + source: this._formatDocumentSource(content.source), + ...(content.citations && { citations: content.citations }), + ...(content.context && { context: content.context }), + }, + } + } + }) + + return { + toolResult: { + toolUseId: block.toolUseId, + content, + ...(this._shouldIncludeToolResultStatus() && { status: block.status }), + }, + } + } + + case 'reasoningBlock': { + if (block.text) { + return { + reasoningContent: { + reasoningText: { + text: block.text, + signature: block.signature, + }, + }, + } + } else if (block.redactedContent) { + return { + reasoningContent: { + redactedContent: block.redactedContent, + }, + } + } else { + throw Error("reasoning content format incorrect. Either 'text' or 'redactedContent' must be set.") + } + } + + case 'cachePointBlock': { + const cachePoint: BedrockCachePointBlock = { type: block.cacheType } + if (block.ttl !== undefined) { + // Bedrock validates TTL values server-side, so accept any string here. + cachePoint.ttl = block.ttl as BedrockSdkCacheTTL + } + return { cachePoint } + } + + case 'imageBlock': + return { + image: { + format: block.format as ImageFormat, + source: this._formatMediaSource(block.source), + }, + } + + case 'videoBlock': + return { + video: { + format: block.format === '3gp' ? 'three_gp' : block.format, + source: this._formatMediaSource(block.source), + }, + } + + case 'documentBlock': + return { + document: { + name: block.name, + format: block.format as DocumentFormat, + source: this._formatDocumentSource(block.source), + ...(block.citations && { citations: block.citations }), + ...(block.context && { context: block.context }), + }, + } + + case 'citationsBlock': + return { + citationsContent: { + citations: block.citations.map((c) => this._mapCitationToBedrock(c)), + content: block.content, + }, + } + + case 'guardContentBlock': { + if (block.text) { + return { + guardContent: { + text: { + text: block.text.text, + qualifiers: block.text.qualifiers, + }, + }, + } + } else if (block.image) { + return { + guardContent: { + image: { + format: block.image.format, + source: { bytes: block.image.source.bytes }, + }, + }, + } + } else { + throw new Error('guardContent must have either text or image') + } + } + } + } + + /** + * Format media source (image/video) for Bedrock API. + * Handles bytes, S3 locations, and s3:// URLs. + * + * @param source - Media source + * @returns Formatted source for Bedrock API + */ + private _formatMediaSource( + source: ImageSource | VideoSource + ): + | BedrockImageSource.BytesMember + | BedrockImageSource.S3LocationMember + | BedrockVideoSource.BytesMember + | BedrockVideoSource.S3LocationMember + | undefined { + switch (source.type) { + case 'imageSourceBytes': + case 'videoSourceBytes': + return { bytes: source.bytes } + + case 'imageSourceUrl': + // Check if URL is actually an S3 URI + if (source.url.startsWith('s3://')) { + return { + s3Location: { + uri: source.url, + }, + } + } + logger.warn('source_type= | not supported by bedrock | skipping') + return + + case 'imageSourceS3Location': + case 'videoSourceS3Location': + return { + s3Location: { + uri: source.location.uri, + ...(source.location.bucketOwner && { bucketOwner: source.location.bucketOwner }), + }, + } + + default: + throw new Error('Invalid media source') + } + } + + /** + * Format document source for Bedrock API. + * Handles bytes, text, content, and S3 locations. + * Note: Bedrock API only accepts bytes, content, or s3Location - text is converted to bytes. + * + * @param source - Document source + * @returns Formatted source for Bedrock API + */ + private _formatDocumentSource( + source: DocumentSource + ): BedrockDocumentSource.BytesMember | BedrockDocumentSource.ContentMember | BedrockDocumentSource.S3LocationMember { + switch (source.type) { + case 'documentSourceBytes': + return { + bytes: source.bytes, + } + + case 'documentSourceText': { + // Convert text to bytes - Bedrock API doesn't accept text directly + const encoder = new TextEncoder() + return { bytes: encoder.encode(source.text) } + } + + case 'documentSourceContentBlock': + return { + content: source.content.map((block) => ({ + text: block.text, + })), + } + + case 'documentSourceS3Location': + return { + s3Location: { + uri: source.location.uri, + ...(source.location.bucketOwner && { bucketOwner: source.location.bucketOwner }), + }, + } + + default: + throw new Error('Invalid document source') + } + } + + private _mapBedrockEventToSDKEvent(event: ConverseCommandOutput): ModelStreamEvent[] { + const events: ModelStreamEvent[] = [] + + // Message start + const output = ensureDefined(event.output, 'event.output') + const message = ensureDefined(output.message, 'output.message') + const role = ensureDefined(message.role, 'message.role') + events.push({ + type: 'modelMessageStartEvent', + role, + }) + + // Match on content blocks + const blockHandlers = { + text: (textBlock: string): void => { + events.push({ type: 'modelContentBlockStartEvent' }) + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: textBlock }, + }) + events.push({ type: 'modelContentBlockStopEvent' }) + }, + toolUse: (block: ToolUseBlock): void => { + events.push({ + type: 'modelContentBlockStartEvent', + start: { + type: 'toolUseStart', + name: ensureDefined(block.name, 'toolUse.name'), + toolUseId: ensureDefined(block.toolUseId, 'toolUse.toolUseId'), + }, + }) + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: JSON.stringify(ensureDefined(block.input, 'toolUse.input')) }, + }) + events.push({ type: 'modelContentBlockStopEvent' }) + }, + reasoningContent: (block: ReasoningContentBlock): void => { + if (!block) return + events.push({ type: 'modelContentBlockStartEvent' }) + + const delta: ReasoningContentDelta = { type: 'reasoningContentDelta' } + if (block.reasoningText) { + delta.text = ensureDefined(block.reasoningText.text, 'reasoningText.text') + if (block.reasoningText.signature) delta.signature = block.reasoningText.signature + } else if (block.redactedContent) { + delta.redactedContent = block.redactedContent + } + + if (Object.keys(delta).length > 1) { + events.push({ type: 'modelContentBlockDeltaEvent', delta }) + } + + events.push({ type: 'modelContentBlockStopEvent' }) + }, + citationsContent: (block: BedrockCitationsContentBlock): void => { + if (!block) return + events.push({ type: 'modelContentBlockStartEvent' }) + + const mapped = this._mapBedrockCitationsData(block) + const delta: CitationsDelta = { + type: 'citationsDelta', + citations: mapped.citations, + content: mapped.content, + } + events.push({ type: 'modelContentBlockDeltaEvent', delta }) + events.push({ type: 'modelContentBlockStopEvent' }) + }, + } + + const content = ensureDefined(message.content, 'message.content') + content.forEach((block) => { + for (const key in block) { + if (key in blockHandlers) { + const handlerKey = key as keyof typeof blockHandlers + // @ts-expect-error - We know the value type corresponds to the handler key. + blockHandlers[handlerKey](block[handlerKey]) + } else { + logger.warn(`block_key=<${key}> | skipping unsupported block key`) + } + } + }) + + const stopReasonRaw = ensureDefined(event.stopReason, 'event.stopReason') as string + events.push({ + type: 'modelMessageStopEvent', + stopReason: this._transformStopReason(stopReasonRaw, event), + }) + + const usage = ensureDefined(event.usage, 'output.usage') + const metadataEvent: ModelStreamEvent = { + type: 'modelMetadataEvent', + usage: { + inputTokens: ensureDefined(usage.inputTokens, 'usage.inputTokens'), + outputTokens: ensureDefined(usage.outputTokens, 'usage.outputTokens'), + totalTokens: ensureDefined(usage.totalTokens, 'usage.totalTokens'), + }, + } + + if (event.metrics) { + metadataEvent.metrics = { + latencyMs: ensureDefined(event.metrics.latencyMs, 'metrics.latencyMs'), + } + } + + // Handle trace and guardrail check for non-streaming responses + if (event.trace) { + metadataEvent.trace = event.trace + + // Check for blocked guardrails and emit redaction events + if (this._config.guardrailConfig && event.trace.guardrail && stopReasonRaw === 'guardrail_intervened') { + for (const redactionEvent of this._generateRedactionEvents(event.trace.guardrail)) { + events.push(redactionEvent) + } + } + } + + events.push(metadataEvent) + + return events + } + + /** + * Maps a Bedrock event to SDK streaming events. + * + * @param chunk - Bedrock event chunk + * @param lastStopReason - Stop reason from previous messageStop event + * @returns Object containing events array and optional stopReason + */ + private _mapStreamedBedrockEventToSDKEvent( + chunk: ConverseStreamOutput, + lastStopReason?: string + ): { events: ModelStreamEvent[]; stopReason?: string } { + const events: ModelStreamEvent[] = [] + let stopReason = lastStopReason + + // Extract the event type key + const eventType = ensureDefined(Object.keys(chunk)[0], 'eventType') as keyof ConverseStreamOutput + const eventData = chunk[eventType as keyof ConverseStreamOutput] + + switch (eventType) { + case 'messageStart': { + const data = eventData as BedrockMessageStartEvent + events.push({ + type: 'modelMessageStartEvent', + role: ensureDefined(data.role, 'messageStart.role'), + }) + break + } + + case 'contentBlockStart': { + const data = eventData as BedrockContentBlockStartEvent + + const event: ModelStreamEvent = { + type: 'modelContentBlockStartEvent', + } + + if (data.start?.toolUse) { + const toolUse = data.start.toolUse + event.start = { + type: 'toolUseStart', + name: ensureDefined(toolUse.name, 'toolUse.name'), + toolUseId: ensureDefined(toolUse.toolUseId, 'toolUse.toolUseId'), + } + } + + events.push(event) + break + } + + case 'contentBlockDelta': { + const data = eventData as BedrockContentBlockDeltaEvent + const delta = ensureDefined(data.delta, 'contentBlockDelta.delta') + const deltaHandlers = { + text: (textValue: string): void => { + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: textValue }, + }) + }, + toolUse: (toolUse: ToolUseBlockDelta): void => { + if (!toolUse?.input) return + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: toolUse.input }, + }) + }, + reasoningContent: (reasoning: ReasoningContentBlockDelta): void => { + if (!reasoning) return + const reasoningDelta: ReasoningContentDelta = { type: 'reasoningContentDelta' } + if (reasoning.text) reasoningDelta.text = reasoning.text + if (reasoning.signature) reasoningDelta.signature = reasoning.signature + if (reasoning.redactedContent) reasoningDelta.redactedContent = reasoning.redactedContent + + if (Object.keys(reasoningDelta).length > 1) { + events.push({ type: 'modelContentBlockDeltaEvent', delta: reasoningDelta }) + } + }, + citation: (citation: BedrockCitationsDelta): void => { + const location = citation.location ? this._mapBedrockCitationLocation(citation.location) : undefined + if (!location) return + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'citationsDelta', + citations: [ + { + location, + sourceContent: (citation.sourceContent ?? []).map((sc) => ({ text: sc.text! })), + source: citation.source ?? '', + title: citation.title ?? '', + }, + ], + content: [], + }, + }) + }, + } + + for (const key in delta) { + if (key in deltaHandlers) { + const handlerKey = key as keyof typeof deltaHandlers + // @ts-expect-error - We know the value type corresponds to the handler key. + deltaHandlers[handlerKey](delta[handlerKey]) + } else { + logger.warn(`delta_key=<${key}> | skipping unsupported delta key`) + } + } + + break + } + + case 'contentBlockStop': { + events.push({ + type: 'modelContentBlockStopEvent', + }) + break + } + + case 'messageStop': { + const data = eventData as BedrockMessageStopEvent + + const stopReasonRaw = ensureDefined(data.stopReason, 'messageStop.stopReason') as string + stopReason = stopReasonRaw + const event: ModelStreamEvent = { + type: 'modelMessageStopEvent', + stopReason: this._transformStopReason(stopReasonRaw, data), + } + + if (data.additionalModelResponseFields) { + event.additionalModelResponseFields = data.additionalModelResponseFields + } + + events.push(event) + break + } + + case 'metadata': { + const data = eventData as BedrockConverseStreamMetadataEvent + + const event: ModelStreamEvent = { + type: 'modelMetadataEvent', + } + + if (data.usage) { + const usage = data.usage + + const usageInfo: Usage = { + inputTokens: ensureDefined(usage.inputTokens, 'usage.inputTokens'), + outputTokens: ensureDefined(usage.outputTokens, 'usage.outputTokens'), + totalTokens: ensureDefined(usage.totalTokens, 'usage.totalTokens'), + } + + if (usage.cacheReadInputTokens !== undefined) { + usageInfo.cacheReadInputTokens = usage.cacheReadInputTokens + } + if (usage.cacheWriteInputTokens !== undefined) { + usageInfo.cacheWriteInputTokens = usage.cacheWriteInputTokens + } + + event.usage = usageInfo + } + + if (data.metrics) { + event.metrics = { + latencyMs: ensureDefined(data.metrics.latencyMs, 'metrics.latencyMs'), + } + } + + if (data.trace) { + event.trace = data.trace + + // Check for blocked guardrails in trace and emit redaction events + if (this._config.guardrailConfig && data.trace.guardrail && lastStopReason === 'guardrail_intervened') { + for (const redactionEvent of this._generateRedactionEvents(data.trace.guardrail)) { + events.push(redactionEvent) + } + } + } + + events.push(event) + break + } + case 'internalServerException': + case 'modelStreamErrorException': + case 'serviceUnavailableException': + case 'validationException': { + throw eventData + } + case 'throttlingException': { + const message = (eventData as { message?: string }).message ?? 'Request was throttled by the model provider' + logger.debug(`throttled | error_message=<${message}>`) + throw new ModelThrottledError(message, { cause: eventData }) + } + default: + // Log warning for unsupported event types (for forward compatibility) + logger.warn(`event_type=<${eventType}> | unsupported bedrock event type`) + break + } + + return stopReason !== undefined ? { events, stopReason } : { events } + } + + /** + * Transforms a Bedrock stop reason into the SDK's format. + * + * @param stopReasonRaw - The raw stop reason string from Bedrock. + * @param event - The full event output, used to check for tool_use adjustments. + * @returns The transformed stop reason. + */ + private _transformStopReason( + stopReasonRaw: string, + event?: ConverseCommandOutput | BedrockMessageStopEvent + ): StopReason { + let mappedStopReason: StopReason + + if (stopReasonRaw in STOP_REASON_MAP) { + mappedStopReason = STOP_REASON_MAP[stopReasonRaw as keyof typeof STOP_REASON_MAP] + } else { + const camelCaseReason = snakeToCamel(stopReasonRaw) + logger.warn( + `stop_reason=<${stopReasonRaw}>, fallback=<${camelCaseReason}> | unknown stop reason, converting to camelCase` + ) + mappedStopReason = camelCaseReason + } + + // Adjust for tool_use, which is sometimes incorrectly reported as end_turn + if ( + mappedStopReason === 'endTurn' && + event && + 'output' in event && + event.output?.message?.content?.some((block) => 'toolUse' in block) + ) { + mappedStopReason = 'toolUse' + logger.warn('stop_reason= | adjusting to tool_use due to tool use in content blocks') + } + + return mappedStopReason + } + + /** + * Maps a Bedrock object-key citation location to the SDK's type-field format. + * + * Bedrock uses object-key discrimination (`{ documentChar: { ... } }`) while the SDK uses + * type-field discrimination (`{ type: 'documentChar', ... }`). Also normalizes Bedrock's + * `searchResultLocation` key to the shorter `searchResult`. + * + * @param bedrockLocation - Bedrock citation location with object-key discrimination + * @returns SDK CitationLocation with type field discrimination + */ + private _mapBedrockCitationLocation(bedrockLocation: BedrockCitationLocation): CitationLocation | undefined { + if (bedrockLocation.documentChar) { + const loc = bedrockLocation.documentChar + return { type: 'documentChar', documentIndex: loc.documentIndex!, start: loc.start!, end: loc.end! } + } + if (bedrockLocation.documentPage) { + const loc = bedrockLocation.documentPage + return { type: 'documentPage', documentIndex: loc.documentIndex!, start: loc.start!, end: loc.end! } + } + if (bedrockLocation.documentChunk) { + const loc = bedrockLocation.documentChunk + return { type: 'documentChunk', documentIndex: loc.documentIndex!, start: loc.start!, end: loc.end! } + } + if (bedrockLocation.searchResultLocation) { + const loc = bedrockLocation.searchResultLocation + return { type: 'searchResult', searchResultIndex: loc.searchResultIndex!, start: loc.start!, end: loc.end! } + } + if (bedrockLocation.web) { + const loc = bedrockLocation.web + return { type: 'web', url: loc.url!, ...(loc.domain && { domain: loc.domain }) } + } + logger.warn(`citation_location=<${JSON.stringify(bedrockLocation)}> | unknown citation location type`) + return undefined + } + + /** + * Maps a Bedrock CitationsContentBlock to SDK CitationsBlockData. + * + * @param bedrockData - Bedrock CitationsContentBlock + * @returns SDK CitationsBlockData with type-field CitationLocations + */ + private _mapBedrockCitationsData(bedrockData: BedrockCitationsContentBlock): CitationsBlockData { + return { + citations: (bedrockData.citations ?? []) + .map((citation) => { + const location = citation.location ? this._mapBedrockCitationLocation(citation.location) : undefined + if (!location) return undefined + return { + source: citation.source ?? '', + title: citation.title ?? '', + sourceContent: (citation.sourceContent ?? []).map((sc) => ({ text: sc.text! })), + location, + } + }) + .filter((c) => c !== undefined), + content: (bedrockData.content ?? []).map((gc) => ({ text: gc.text! })), + } + } + + /** + * Maps an SDK Citation to Bedrock's Citation format. + * + * @param citation - SDK Citation with type-field location + * @returns Bedrock Citation with object-key location + */ + private _mapCitationToBedrock(citation: Citation): BedrockCitation { + return { + location: this._mapCitationLocationToBedrock(citation.location), + sourceContent: citation.sourceContent.map((sc) => ({ text: sc.text })), + source: citation.source, + title: citation.title, + } + } + + /** + * Maps an SDK CitationLocation to Bedrock's object-key format. + * + * @param location - SDK CitationLocation with type field + * @returns Bedrock CitationLocation with object-key discrimination + */ + private _mapCitationLocationToBedrock(location: CitationLocation): BedrockCitationLocation { + switch (location.type) { + case 'documentChar': { + const { type: _, ...fields } = location + return { documentChar: fields } + } + case 'documentPage': { + const { type: _, ...fields } = location + return { documentPage: fields } + } + case 'documentChunk': { + const { type: _, ...fields } = location + return { documentChunk: fields } + } + case 'searchResult': { + const { type: _, ...fields } = location + return { searchResultLocation: fields } + } + case 'web': + return { web: { url: location.url, ...(location.domain && { domain: location.domain }) } } + default: + return location as unknown as BedrockCitationLocation + } + } + + /** + * Generate redaction events based on guardrail configuration. + * + * @param guardrailData - The guardrail trace assessment data + * @returns Array of redaction events to emit + */ + private _generateRedactionEvents(guardrailData: GuardrailTraceAssessment): ModelStreamEvent[] { + const events: ModelStreamEvent[] = [] + const redaction = this._config.guardrailConfig?.redaction + + // Default: redact input is true unless explicitly set to false + if (redaction?.input !== false) { + logger.debug('redacting input due to guardrail') + events.push({ + type: 'modelRedactionEvent', + inputRedaction: { + replaceContent: redaction?.inputMessage ?? DEFAULT_REDACT_INPUT_MESSAGE, + }, + }) + } + + // Only redact output if explicitly enabled + if (redaction?.output) { + logger.debug('redacting output due to guardrail') + const outputRedactionEvent: ModelStreamEvent = { + type: 'modelRedactionEvent', + outputRedaction: { + replaceContent: redaction?.outputMessage ?? DEFAULT_REDACT_OUTPUT_MESSAGE, + }, + } + + // Include the original model output if available + if (guardrailData.modelOutput && guardrailData.modelOutput.length > 0) { + outputRedactionEvent.outputRedaction!.redactedContent = guardrailData.modelOutput.join('') + } + + events.push(outputRedactionEvent) + } + + return events + } +} + +/** + * Merges a default request timeout into the caller's requestHandler options. + * + * The SDK's `requestHandler` slot accepts either a constructed handler instance + * or an options bag that the SDK uses to build its default handler. We only + * inject a default in the options-bag case: a handler instance has its timeouts + * baked in at construction time, so we pass it through untouched. + * + * The handler-vs-options discriminator mirrors the SDK's own check — see + * `NodeHttp2Handler.create` in `@smithy/node-http-handler`. + */ +function withDefaultRequestTimeout( + handler: BedrockRuntimeClientConfig['requestHandler'] +): NonNullable { + if (handler && typeof (handler as { handle?: unknown }).handle === 'function') { + return handler + } + const options = (handler ?? {}) as { requestTimeout?: number; [key: string]: unknown } + // Use `??` rather than spread order so an explicit `requestTimeout: undefined` still gets + // the default (spread would otherwise overwrite the default with `undefined`, disabling it). + return { ...options, requestTimeout: options.requestTimeout ?? DEFAULT_REQUEST_TIMEOUT_MS } +} + +/** + * Adds middleware to override the Authorization header with a Bearer token. + * Runs after SigV4 signing to replace the signature with the API key. + * + * @param client - BedrockRuntimeClient instance + * @param apiKey - Bedrock API key + */ +function applyApiKey(client: BedrockRuntimeClient, apiKey: string): void { + client.middlewareStack.add( + // eslint-disable-next-line @typescript-eslint/explicit-function-return-type + (next) => async (args) => { + const request = args.request as { headers: Record } + request.headers['authorization'] = `Bearer ${apiKey}` + return next(args) + }, + { + step: 'finalizeRequest', + priority: 'low', + name: 'bedrockApiKeyMiddleware', + } + ) +} + +/** + * What region is used for the BedrockConfiguration can't be known at construction-time so to apply a default + * we have to use an async function to intercept "Region is missing" errors and then apply our default (this + * is actually how many bedrock configuration parameters are implemented). + * + * We need to override both region & useFipsEndpoint because the region is used in both of those places: + * https://github.com/smithy-lang/smithy-typescript/blob/e11f7499c1bad30a515217f82a07b9e3e69a1f60/packages/config-resolver/src/regionConfig/resolveRegionConfig.ts#L42 + * + * We do this unconditionally so that if a region is updated dynamically (environment variable or profile value) we + * also pick up those changes and stop applying the default. + */ +function applyDefaultRegion(config: BedrockRuntimeClientResolvedConfig): void { + // Bind original region function and wrap with error handling + const originalRegion = config.region.bind(config) + config.region = async (): Promise => { + try { + return await originalRegion() + } catch (error) { + // Note: it was observed that the browser version of the BedrockClient + // uses a string instead of an error object - thus the normalizeError call + if (normalizeError(error).message === 'Region is missing') { + return MODEL_DEFAULTS.bedrock.region + } + + throw error + } + } + + // Bind original useFipsEndpoint function and wrap with error handling + const originalUseFipsEndpoint = config.useFipsEndpoint.bind(config) + config.useFipsEndpoint = async (): Promise => { + try { + return await originalUseFipsEndpoint() + } catch (error) { + // Note: it was observed that the browser version of the BedrockClient + // uses a string instead of an error object - thus the normalizeError call + if (normalizeError(error).message === 'Region is missing') { + return DEFAULT_BEDROCK_REGION_SUPPORTS_FIP + } + throw error + } + } +} diff --git a/strands-ts/src/models/defaults.ts b/strands-ts/src/models/defaults.ts new file mode 100644 index 0000000000..7cab7dff97 --- /dev/null +++ b/strands-ts/src/models/defaults.ts @@ -0,0 +1,180 @@ +/** + * Default values for model providers. + * + * These defaults are subject to change between versions. Set values explicitly + * on model configurations to pin behavior across upgrades. + */ + +export const MODEL_DEFAULTS = { + anthropic: { + modelId: 'claude-sonnet-4-6', + maxTokens: 64_000, + }, + bedrock: { + modelId: 'global.anthropic.claude-sonnet-4-6', + region: 'us-west-2', + }, + openai: { + modelId: 'gpt-5.4', + }, + gemini: { + modelId: 'gemini-2.5-flash', + }, +} as const + +/** + * Builds a warning message for when the default model ID is used. + * + * @param defaultModelId - The default model ID being used + * @returns Formatted warning message string + */ +export function defaultModelWarningMessage(defaultModelId: string): string { + return `model_id=<${defaultModelId}> | using default modelId, which is subject to change | set modelId explicitly to pin the value` +} + +/** + * Builds a warning message for when the default max tokens value is used. + * + * @param defaultMaxTokens - The default max tokens value being used + * @returns Formatted warning message string + */ +export function defaultMaxTokensWarningMessage(defaultMaxTokens: number): string { + return `max_tokens=<${defaultMaxTokens}> | using default maxTokens, which is subject to change | set maxTokens explicitly to pin the value` +} + +/** + * Context window limits (in tokens) for known model IDs. + * + * Best-effort lookup table — unknown models return `undefined` and callers + * fall back gracefully (e.g. proactive compression is disabled). + * Entries can be pruned when a model is no longer available from the provider. + * Users can always override with an explicit `contextWindowLimit` in their model config. + * + * Values sourced from provider documentation and + * https://github.com/BerriAI/litellm/blob/litellm_internal_staging/model_prices_and_context_window.json + * + * For Bedrock models with cross-region prefixes (e.g. `us.`, `eu.`, `global.`), + * {@link getContextWindowLimit} strips the prefix before lookup so only the base model ID is needed here. + */ +const CONTEXT_WINDOW_LIMITS: Record = { + // Anthropic (direct API) + 'claude-sonnet-4-6': 1_000_000, + 'claude-sonnet-4-20250514': 1_000_000, + 'claude-sonnet-4-5': 200_000, + 'claude-sonnet-4-5-20250929': 200_000, + 'claude-opus-4-6': 1_000_000, + 'claude-opus-4-6-20260205': 1_000_000, + 'claude-opus-4-7': 1_000_000, + 'claude-opus-4-7-20260416': 1_000_000, + 'claude-opus-4-5': 200_000, + 'claude-opus-4-5-20251101': 200_000, + 'claude-opus-4-20250514': 200_000, + 'claude-opus-4-1': 200_000, + 'claude-opus-4-1-20250805': 200_000, + 'claude-haiku-4-5': 200_000, + 'claude-haiku-4-5-20251001': 200_000, + 'claude-3-7-sonnet-20250219': 200_000, + 'claude-3-5-sonnet-20241022': 200_000, + 'claude-3-5-sonnet-20240620': 200_000, + 'claude-3-5-haiku-20241022': 200_000, + 'claude-3-opus-20240229': 200_000, + 'claude-3-haiku-20240307': 200_000, + + // Bedrock Anthropic (base model IDs — cross-region prefixes stripped by getContextWindowLimit) + 'anthropic.claude-sonnet-4-6': 1_000_000, + 'anthropic.claude-sonnet-4-20250514-v1:0': 1_000_000, + 'anthropic.claude-sonnet-4-5-20250929-v1:0': 200_000, + 'anthropic.claude-opus-4-6-v1': 1_000_000, + 'anthropic.claude-opus-4-7': 1_000_000, + 'anthropic.claude-opus-4-5-20251101-v1:0': 200_000, + 'anthropic.claude-opus-4-20250514-v1:0': 200_000, + 'anthropic.claude-opus-4-1-20250805-v1:0': 200_000, + 'anthropic.claude-haiku-4-5-20251001-v1:0': 200_000, + 'anthropic.claude-haiku-4-5@20251001': 200_000, + 'anthropic.claude-3-7-sonnet-20250219-v1:0': 200_000, + 'anthropic.claude-3-7-sonnet-20240620-v1:0': 200_000, + 'anthropic.claude-3-5-sonnet-20241022-v2:0': 200_000, + 'anthropic.claude-3-5-sonnet-20240620-v1:0': 200_000, + 'anthropic.claude-3-5-haiku-20241022-v1:0': 200_000, + 'anthropic.claude-3-opus-20240229-v1:0': 200_000, + 'anthropic.claude-3-haiku-20240307-v1:0': 200_000, + 'anthropic.claude-3-sonnet-20240229-v1:0': 200_000, + 'anthropic.claude-mythos-preview': 1_000_000, + + // Bedrock Amazon Nova + 'amazon.nova-pro-v1:0': 300_000, + 'amazon.nova-lite-v1:0': 300_000, + 'amazon.nova-micro-v1:0': 128_000, + 'amazon.nova-premier-v1:0': 1_000_000, + 'amazon.nova-2-lite-v1:0': 1_000_000, + 'amazon.nova-2-pro-preview-20251202-v1:0': 1_000_000, + + // OpenAI + 'gpt-5.5': 1_050_000, + 'gpt-5.5-pro': 1_050_000, + 'gpt-5.4': 1_050_000, + 'gpt-5.4-pro': 1_050_000, + 'gpt-5.4-mini': 272_000, + 'gpt-5.4-nano': 272_000, + 'gpt-5.2': 272_000, + 'gpt-5.2-pro': 272_000, + 'gpt-5.1': 272_000, + 'gpt-5': 272_000, + 'gpt-5-mini': 272_000, + 'gpt-5-nano': 272_000, + 'gpt-5-pro': 128_000, + 'gpt-4.1': 1_047_576, + 'gpt-4.1-mini': 1_047_576, + 'gpt-4.1-nano': 1_047_576, + 'gpt-4o': 128_000, + 'gpt-4o-mini': 128_000, + 'gpt-4-turbo': 128_000, + o3: 200_000, + 'o3-mini': 200_000, + 'o3-pro': 200_000, + 'o4-mini': 200_000, + o1: 200_000, + + // Google Gemini + 'gemini-2.5-flash': 1_048_576, + 'gemini-2.5-flash-lite': 1_048_576, + 'gemini-2.5-pro': 1_048_576, + 'gemini-2.0-flash': 1_048_576, + 'gemini-2.0-flash-lite': 1_048_576, + 'gemini-3-pro-preview': 1_048_576, + 'gemini-3-flash-preview': 1_048_576, + 'gemini-3.1-pro-preview': 1_048_576, + 'gemini-3.1-flash-lite-preview': 1_048_576, +} + +/** + * Known Bedrock cross-region routing prefixes. + * + * @see https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html + */ +const BEDROCK_REGION_PREFIXES = new Set(['us', 'eu', 'ap', 'global', 'apac', 'au', 'jp', 'us-gov']) + +/** + * Looks up the context window limit for a model ID. + * + * For Bedrock cross-region model IDs (e.g. `us.anthropic.claude-sonnet-4-6`), + * the region prefix is stripped before lookup. + * + * @param modelId - The model ID to look up + * @returns The context window limit in tokens, or undefined if not found + */ +export function getContextWindowLimit(modelId: string): number | undefined { + const direct = CONTEXT_WINDOW_LIMITS[modelId] + if (direct !== undefined) return direct + + // Strip known Bedrock cross-region prefixes + const dotIndex = modelId.indexOf('.') + if (dotIndex !== -1) { + const prefix = modelId.substring(0, dotIndex) + if (BEDROCK_REGION_PREFIXES.has(prefix)) { + return CONTEXT_WINDOW_LIMITS[modelId.substring(dotIndex + 1)] + } + } + + return undefined +} diff --git a/strands-ts/src/models/google/adapters.ts b/strands-ts/src/models/google/adapters.ts new file mode 100644 index 0000000000..6cfe79605c --- /dev/null +++ b/strands-ts/src/models/google/adapters.ts @@ -0,0 +1,492 @@ +/** + * Adapters for converting between Strands SDK types and Gemini API format. + * + * @internal This module is not part of the public API. + */ + +import { + type Content, + type GenerateContentResponse, + type Part, + FunctionResponse, + FinishReason as GeminiFinishReason, +} from '@google/genai' +import type { + Message, + StopReason, + ContentBlock, + ReasoningBlock, + ToolUseBlock, + ToolResultBlock, +} from '../../types/messages.js' +import type { ModelStreamEvent } from '../streaming.js' +import type { GoogleStreamState } from './types.js' +import { encodeBase64, type ImageBlock, type DocumentBlock, type VideoBlock } from '../../types/media.js' +import { toMimeType } from '../../mime.js' +import { logger } from '../../logging/logger.js' + +/** + * Mapping of Gemini finish reasons to SDK stop reasons. + * Only MAX_TOKENS needs explicit mapping; everything else defaults to endTurn. + * Tool use stop reason is determined by the hasToolCalls flag in GoogleStreamState, + * since Gemini does not have a tool use finish reason. + * + * @internal + */ +export const FINISH_REASON_MAP: Partial> = { + [GeminiFinishReason.MAX_TOKENS]: 'maxTokens', +} + +// ============================================================================= +// Strands → Gemini +// ============================================================================= + +/** + * Formats an array of messages for the Gemini API. + * + * @param messages - SDK messages to format + * @returns Gemini-formatted contents array + * + * @internal + */ +export function formatMessages(messages: Message[]): Content[] { + const contents: Content[] = [] + + // Build toolUseId → name mapping for resolving tool result names + const toolUseIdToName = new Map() + for (const message of messages) { + for (const block of message.content) { + if (block.type === 'toolUseBlock') { + toolUseIdToName.set(block.toolUseId, block.name) + } + } + } + + for (const message of messages) { + const parts: Part[] = [] + + for (const block of message.content) { + parts.push(...formatContentBlock(block, toolUseIdToName)) + } + + if (parts.length > 0) { + contents.push({ + role: message.role === 'assistant' ? 'model' : 'user', + parts, + }) + } + } + + return contents +} + +/** + * Formats a content block to Gemini Parts. + * + * @param block - SDK content block + * @returns Array of Gemini Parts + * + * @internal + */ +function formatContentBlock(block: ContentBlock, toolUseIdToName: Map): Part[] { + switch (block.type) { + case 'textBlock': + return [{ text: block.text }] + + case 'imageBlock': + return formatImageBlock(block) + + case 'reasoningBlock': + return formatReasoningBlock(block) + + case 'documentBlock': + return formatDocumentBlock(block) + + case 'videoBlock': + return formatVideoBlock(block) + + case 'toolUseBlock': + return formatToolUseBlock(block) + + case 'toolResultBlock': + return formatToolResultBlock(block, toolUseIdToName) + + case 'cachePointBlock': + logger.warn('block_type= | cache points not supported by gemini, skipping') + return [] + + case 'guardContentBlock': + logger.warn('block_type= | guard content not supported by gemini, skipping') + return [] + + default: + return [] + } +} + +/** + * Formats an image block to Gemini Parts. + * + * @param block - Image block to format + * @returns Array of Gemini Parts + * + * @internal + */ +function formatImageBlock(block: ImageBlock): Part[] { + const mimeType = toMimeType(block.format) ?? `image/${block.format}` + + switch (block.source.type) { + case 'imageSourceBytes': + return [{ inlineData: { data: encodeBase64(block.source.bytes), mimeType } }] + + case 'imageSourceUrl': + return [{ fileData: { fileUri: block.source.url, mimeType } }] + + case 'imageSourceS3Location': + logger.warn('source_type= | s3 sources not supported by gemini, skipping') + return [] + + default: + return [] + } +} + +/** + * Formats a reasoning block to Gemini Parts. + * + * @param block - Reasoning block to format + * @returns Array of Gemini Parts + * + * @internal + */ +function formatReasoningBlock(block: ReasoningBlock): Part[] { + if (!block.text) { + return [] + } + + const part: Part = { + text: block.text, + thought: true, + } + + // Add thought signature if present + if (block.signature) { + part.thoughtSignature = block.signature + } + + return [part] +} + +/** + * Formats a document block to Gemini Parts. + * + * @param block - Document block to format + * @returns Array of Gemini Parts + * + * @internal + */ +function formatDocumentBlock(block: DocumentBlock): Part[] { + const mimeType = toMimeType(block.format) ?? `application/${block.format}` + + switch (block.source.type) { + case 'documentSourceBytes': + return [{ inlineData: { data: encodeBase64(block.source.bytes), mimeType } }] + + case 'documentSourceText': + // Convert text to bytes - Gemini API doesn't accept text directly + return [{ inlineData: { data: encodeBase64(new TextEncoder().encode(block.source.text)), mimeType } }] + + case 'documentSourceContentBlock': + return block.source.content.map((contentBlock) => ({ text: contentBlock.text })) + + case 'documentSourceS3Location': + logger.warn('source_type= | s3 sources not supported by gemini, skipping') + return [] + + default: + return [] + } +} + +/** + * Formats a video block to Gemini Parts. + * + * @param block - Video block to format + * @returns Array of Gemini Parts + * + * @internal + */ +function formatVideoBlock(block: VideoBlock): Part[] { + const mimeType = toMimeType(block.format) ?? `video/${block.format}` + + switch (block.source.type) { + case 'videoSourceBytes': + return [{ inlineData: { data: encodeBase64(block.source.bytes), mimeType } }] + + case 'videoSourceS3Location': + logger.warn('source_type= | s3 sources not supported by gemini, skipping') + return [] + + default: + return [] + } +} + +/** + * Formats a tool use block to a Gemini Part. + * + * @param block - Tool use block to format + * @returns Array of Gemini Parts with functionCall + * + * @internal + */ +function formatToolUseBlock(block: ToolUseBlock): Part[] { + return [ + { + functionCall: { + id: block.toolUseId, + name: block.name, + args: block.input as Record, + }, + ...(block.reasoningSignature && { thoughtSignature: block.reasoningSignature }), + }, + ] +} + +/** + * Formats a tool result block to a Gemini Part. + * + * @param block - Tool result block to format + * @param toolUseIdToName - Mapping from tool use IDs to tool names + * @returns Array of Gemini Parts with functionResponse + * + * @internal + */ +function formatToolResultBlock(block: ToolResultBlock, toolUseIdToName: Map): Part[] { + const parts: Part[] = [] + const output: Array<{ text?: string; json?: unknown }> = [] + + for (const c of block.content) { + switch (c.type) { + case 'textBlock': + output.push({ text: c.text }) + break + case 'jsonBlock': + output.push({ json: c.json }) + break + case 'imageBlock': { + const mimeType = toMimeType(c.format) ?? `image/${c.format}` + if (c.source.type === 'imageSourceBytes') { + parts.push({ + inlineData: { + data: encodeBase64(c.source.bytes), + mimeType, + displayName: `image.${c.format}`, + }, + }) + } else { + logger.warn('source_type=<%s> | only bytes sources supported in gemini tool results', c.source.type) + } + break + } + case 'documentBlock': { + const mimeType = toMimeType(c.format) ?? `application/${c.format}` + if (c.source.type === 'documentSourceBytes') { + parts.push({ + inlineData: { + data: encodeBase64(c.source.bytes), + mimeType, + displayName: c.name, + }, + }) + } else if (c.source.type === 'documentSourceText') { + parts.push({ + inlineData: { + data: encodeBase64(new TextEncoder().encode(c.source.text)), + mimeType, + displayName: c.name, + }, + }) + } else { + logger.warn('source_type=<%s> | only bytes/text sources supported in gemini tool results', c.source.type) + } + break + } + case 'videoBlock': + logger.warn('block_type= | videos not supported in gemini tool results, skipping') + break + } + } + + const functionResponse = new FunctionResponse() + functionResponse.id = block.toolUseId + functionResponse.name = toolUseIdToName.get(block.toolUseId) ?? block.toolUseId + functionResponse.response = { output } + if (parts.length > 0) { + functionResponse.parts = parts + } + + return [{ functionResponse }] +} + +// ============================================================================= +// Gemini → Strands +// ============================================================================= + +/** + * Maps a Gemini response chunk to SDK streaming events. + * + * @param chunk - Gemini response chunk + * @param streamState - Mutable state object tracking message and content block state + * @returns Array of SDK streaming events + * + * @internal + */ +export function mapChunkToEvents(chunk: GenerateContentResponse, streamState: GoogleStreamState): ModelStreamEvent[] { + const events: ModelStreamEvent[] = [] + + // Extract usage metadata if available + if (chunk.usageMetadata) { + const promptTokens = chunk.usageMetadata.promptTokenCount || 0 + const totalTokens = chunk.usageMetadata.totalTokenCount || 0 + streamState.inputTokens = promptTokens + streamState.outputTokens = totalTokens - promptTokens + } + + const candidates = chunk.candidates + if (!candidates || candidates.length === 0) { + return events + } + + const candidate = candidates[0] + if (!candidate) { + return events + } + + // Handle message start + if (!streamState.messageStarted) { + streamState.messageStarted = true + events.push({ + type: 'modelMessageStartEvent', + role: 'assistant', + }) + } + + // Process content parts + const content = candidate.content + if (content && content.parts) { + for (const part of content.parts) { + // Handle function call parts + if (part.functionCall) { + // Close any open text/reasoning blocks before tool use + if (streamState.textContentBlockStarted) { + events.push({ type: 'modelContentBlockStopEvent' }) + streamState.textContentBlockStarted = false + } + if (streamState.reasoningContentBlockStarted) { + events.push({ type: 'modelContentBlockStopEvent' }) + streamState.reasoningContentBlockStarted = false + } + + const toolUseId = part.functionCall.id || `tooluse_${globalThis.crypto.randomUUID()}` + + events.push({ + type: 'modelContentBlockStartEvent', + start: { + type: 'toolUseStart', + name: part.functionCall.name!, + toolUseId, + ...(part.thoughtSignature && { reasoningSignature: part.thoughtSignature }), + }, + }) + + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'toolUseInputDelta', + input: JSON.stringify(part.functionCall.args ?? {}), + }, + }) + + events.push({ type: 'modelContentBlockStopEvent' }) + + streamState.hasToolCalls = true + continue + } + + // Handle text and reasoning parts + if ('text' in part && part.text) { + const isThought = 'thought' in part && part.thought === true + + if (isThought) { + // Handle reasoning content + // Close text block if transitioning from text to reasoning + if (streamState.textContentBlockStarted) { + events.push({ type: 'modelContentBlockStopEvent' }) + streamState.textContentBlockStarted = false + } + + if (!streamState.reasoningContentBlockStarted) { + streamState.reasoningContentBlockStarted = true + events.push({ type: 'modelContentBlockStartEvent' }) + } + + // Extract signature if present + const signature = part.thoughtSignature + + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'reasoningContentDelta', + text: part.text, + ...(signature !== undefined && { signature }), + }, + }) + } else { + // Handle regular text content + // Close reasoning block if transitioning from reasoning to text + if (streamState.reasoningContentBlockStarted) { + events.push({ type: 'modelContentBlockStopEvent' }) + streamState.reasoningContentBlockStarted = false + } + + if (!streamState.textContentBlockStarted) { + streamState.textContentBlockStarted = true + events.push({ type: 'modelContentBlockStartEvent' }) + } + + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'textDelta', + text: part.text, + }, + }) + } + } + } + } + + // Handle finish reason + const finishReason = candidate.finishReason + if (finishReason && finishReason !== GeminiFinishReason.FINISH_REASON_UNSPECIFIED) { + // Close any open content blocks + if (streamState.textContentBlockStarted) { + events.push({ type: 'modelContentBlockStopEvent' }) + streamState.textContentBlockStarted = false + } + if (streamState.reasoningContentBlockStarted) { + events.push({ type: 'modelContentBlockStopEvent' }) + streamState.reasoningContentBlockStarted = false + } + + const stopReason = streamState.hasToolCalls ? 'toolUse' : FINISH_REASON_MAP[finishReason] || 'endTurn' + + events.push({ + type: 'modelMessageStopEvent', + stopReason, + }) + } + + return events +} diff --git a/strands-ts/src/models/google/errors.ts b/strands-ts/src/models/google/errors.ts new file mode 100644 index 0000000000..c715367ee3 --- /dev/null +++ b/strands-ts/src/models/google/errors.ts @@ -0,0 +1,89 @@ +/** + * Error handling utilities for the Google model provider. + * + * @internal This module is not part of the public API. + */ + +import { logger } from '../../logging/logger.js' + +/** + * Recognized error types from Google GenAI API responses. + * + * This union type will expand as more error types are supported + * (e.g., 'throttling', 'invalidRequest'). + */ +export type GoogleErrorType = 'contextOverflow' | 'throttling' + +/** + * Configuration for handling a specific error status. + * If messagePatterns is provided, the error message must match one of the patterns. + * If messagePatterns is not provided, the status alone triggers the error type. + */ +export interface ErrorStatusConfig { + type: GoogleErrorType + messagePatterns?: Set +} + +/** + * Mapping of Google GenAI API error statuses to error handling configuration. + * Maps status codes to either direct error types or message-pattern-based detection. + */ +export const ERROR_STATUS_MAP: Record = { + INVALID_ARGUMENT: { + type: 'contextOverflow', + messagePatterns: new Set(['exceeds the maximum number of tokens']), + }, + RESOURCE_EXHAUSTED: { + type: 'throttling', + }, + UNAVAILABLE: { + type: 'throttling', + }, +} + +/** + * Classifies a Google GenAI API error based on status and message patterns. + * Returns the error type if recognized, undefined otherwise. + * + * @param error - The error to classify + * @returns The error type if recognized, undefined otherwise + * + * @internal + */ +export function classifyGoogleError(error: Error): GoogleErrorType | undefined { + if (!error.message) { + return undefined + } + + let status: string + let message: string + + try { + const parsed = JSON.parse(error.message) + status = parsed?.error?.status || '' + message = parsed?.error?.message || '' + } catch { + logger.debug(`error_message=<${error.message}> | google genai api returned non-json error`) + return undefined + } + + const config = ERROR_STATUS_MAP[status.toUpperCase()] + if (!config) { + return undefined + } + + // If no message patterns required, status alone determines the error type + if (!config.messagePatterns) { + return config.type + } + + // Check if message matches any of the patterns + const lowerMessage = message.toLowerCase() + for (const pattern of config.messagePatterns) { + if (lowerMessage.includes(pattern)) { + return config.type + } + } + + return undefined +} diff --git a/strands-ts/src/models/google/index.ts b/strands-ts/src/models/google/index.ts new file mode 100644 index 0000000000..e167595d28 --- /dev/null +++ b/strands-ts/src/models/google/index.ts @@ -0,0 +1,15 @@ +/** + * Google model provider. + * + * @example + * ```typescript + * import { GoogleModel } from '@strands-agents/sdk/models/google' + * + * const model = new GoogleModel({ + * apiKey: 'your-api-key', + * modelId: 'gemini-2.5-flash', + * }) + * ``` + */ + +export { GoogleModel, type GoogleModelConfig, type GoogleModelOptions } from './model.js' diff --git a/strands-ts/src/models/google/model.ts b/strands-ts/src/models/google/model.ts new file mode 100644 index 0000000000..9855ca1064 --- /dev/null +++ b/strands-ts/src/models/google/model.ts @@ -0,0 +1,355 @@ +/** + * Google model provider implementation. + * + * This module provides integration with Google's Gemini API, + * supporting streaming responses and configurable model parameters. + * + * @see https://ai.google.dev/docs + */ + +import { + GoogleGenAI, + FunctionCallingConfigMode, + type GenerateContentConfig, + type GenerateContentParameters, +} from '@google/genai' +import { Model, resolveConfigMetadata } from '../model.js' +import type { CountTokensOptions, StreamOptions } from '../model.js' +import type { Message } from '../../types/messages.js' +import type { ModelStreamEvent } from '../streaming.js' +import { ContextWindowOverflowError, ModelThrottledError, ProviderTokenCountError } from '../../errors.js' +import type { GoogleModelConfig, GoogleModelOptions, GoogleStreamState } from './types.js' +export type { GoogleModelConfig, GoogleModelOptions } +import { classifyGoogleError } from './errors.js' +import { formatMessages, mapChunkToEvents } from './adapters.js' +import { MODEL_DEFAULTS, defaultModelWarningMessage } from '../defaults.js' +import { warnOnce } from '../../logging/warn-once.js' +import { logger } from '../../logging/logger.js' + +/** + * Google model provider implementation. + * + * Implements the Model interface for Google GenAI using the Generative AI API. + * Supports streaming responses and comprehensive configuration. + * + * @example + * ```typescript + * const provider = new GoogleModel({ + * apiKey: 'your-api-key', + * modelId: 'gemini-2.5-flash', + * params: { temperature: 0.7, maxOutputTokens: 1024 } + * }) + * + * const messages: Message[] = [ + * { role: 'user', content: [{ type: 'textBlock', text: 'Hello!' }] } + * ] + * + * for await (const event of provider.stream(messages)) { + * if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + * process.stdout.write(event.delta.text) + * } + * } + * ``` + */ +export class GoogleModel extends Model { + private _config: GoogleModelConfig + private _client: GoogleGenAI + + /** + * Creates a new GoogleModel instance. + * + * @param options - Configuration for model and client + * + * @example + * ```typescript + * // Minimal configuration with API key + * const provider = new GoogleModel({ + * apiKey: 'your-api-key' + * }) + * + * // With model configuration + * const provider = new GoogleModel({ + * apiKey: 'your-api-key', + * modelId: 'gemini-2.5-flash', + * params: { temperature: 0.8, maxOutputTokens: 2048 } + * }) + * + * // Using environment variable for API key + * const provider = new GoogleModel({ + * modelId: 'gemini-2.5-flash' + * }) + * + * // Using a pre-configured client instance + * const client = new GoogleGenAI({ apiKey: 'your-api-key' }) + * const provider = new GoogleModel({ + * client + * }) + * ``` + */ + constructor(options?: GoogleModelOptions) { + super() + const { apiKey, client, clientConfig, ...modelConfig } = options || {} + + this._config = modelConfig + + if (modelConfig.modelId === undefined) { + warnOnce(logger, defaultModelWarningMessage(MODEL_DEFAULTS.gemini.modelId)) + } + + if (client) { + this._client = client + } else { + const resolvedApiKey = apiKey || GoogleModel._getEnvApiKey() + + if (!resolvedApiKey) { + throw new Error( + "Gemini API key is required. Provide it via the 'apiKey' option or set the GEMINI_API_KEY environment variable." + ) + } + + this._client = new GoogleGenAI({ + apiKey: resolvedApiKey, + ...clientConfig, + }) + } + } + + /** + * Updates the model configuration. + * Merges the provided configuration with existing settings. + * + * @param modelConfig - Configuration object with model-specific settings to update + * + * @example + * ```typescript + * // Update model parameters + * provider.updateConfig({ + * params: { temperature: 0.9, maxOutputTokens: 2048 } + * }) + * ``` + */ + updateConfig(modelConfig: GoogleModelConfig): void { + this._config = { ...this._config, ...modelConfig } + } + + /** + * Retrieves the current model configuration. + * + * @returns The current configuration object + * + * @example + * ```typescript + * const config = provider.getConfig() + * console.log(config.modelId) + * ``` + */ + getConfig(): GoogleModelConfig { + return resolveConfigMetadata(this._config, this._config.modelId ?? MODEL_DEFAULTS.gemini.modelId) + } + + /** + * Count tokens using Gemini's native countTokens API. + * + * Uses the Gemini countTokens API for message contents. System instructions and tools + * are estimated via the base class heuristic because the Gemini API (non-Vertex backend) + * does not support these in CountTokensConfig. + * Falls back to the base class heuristic on failure. + * + * @param messages - Array of conversation messages to count tokens for + * @param options - Optional options containing system prompt and tool specs + * @returns Total input token count + */ + override async countTokens(messages: Message[], options?: CountTokensOptions): Promise { + if (this._config.useNativeTokenCount !== true) return super.countTokens(messages, options) + + try { + const params = this._formatRequest(messages, options) + const modelId = params.model + + // The Gemini API (non-Vertex backend) raises an error for systemInstruction and tools + // in CountTokensConfig. Use native counting for message contents only, then add the + // heuristic estimate for system prompt and tools. + const response = await this._client.models.countTokens({ + model: modelId, + contents: params.contents, + }) + + if (response.totalTokens == null) { + throw new ProviderTokenCountError('Gemini countTokens returned null for totalTokens') + } + + let totalTokens = response.totalTokens + + // Add heuristic estimate for system prompt and tools (not supported by the API) + if (options?.systemPrompt || options?.toolSpecs) { + totalTokens += await super.countTokens([], { + ...(options.systemPrompt && { systemPrompt: options.systemPrompt }), + ...(options.toolSpecs && { toolSpecs: options.toolSpecs }), + }) + } + + logger.debug(`total_tokens=<${totalTokens}> | native token count`) + return totalTokens + } catch (error) { + logger.debug(`error=<${error}> | native token counting failed, falling back to estimation`) + return super.countTokens(messages, options) + } + } + + /** + * Streams a conversation with the Google model. + * Returns an async iterable that yields streaming events as they occur. + * + * @param messages - Array of conversation messages + * @param options - Optional streaming configuration + * @returns Async iterable of streaming events + * + * @throws \{ContextWindowOverflowError\} When input exceeds the model's context window + * + * @example + * ```typescript + * const provider = new GoogleModel({ apiKey: 'your-api-key' }) + * const messages: Message[] = [ + * { role: 'user', content: [{ type: 'textBlock', text: 'What is 2+2?' }] } + * ] + * + * for await (const event of provider.stream(messages)) { + * if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + * process.stdout.write(event.delta.text) + * } + * } + * ``` + */ + async *stream(messages: Message[], options?: StreamOptions): AsyncIterable { + if (!messages || messages.length === 0) { + throw new Error('At least one message is required') + } + + try { + const params = this._formatRequest(messages, options) + const stream = await this._client.models.generateContentStream(params) + + const streamState: GoogleStreamState = { + messageStarted: false, + textContentBlockStarted: false, + reasoningContentBlockStarted: false, + hasToolCalls: false, + inputTokens: 0, + outputTokens: 0, + } + + for await (const chunk of stream) { + yield* mapChunkToEvents(chunk, streamState) + } + + if (streamState.inputTokens > 0 || streamState.outputTokens > 0) { + yield { + type: 'modelMetadataEvent', + usage: { + inputTokens: streamState.inputTokens, + outputTokens: streamState.outputTokens, + totalTokens: streamState.inputTokens + streamState.outputTokens, + }, + } + } + } catch (error) { + if (!(error instanceof Error)) { + throw error + } + const errorType = classifyGoogleError(error) + + if (errorType === 'contextOverflow') { + throw new ContextWindowOverflowError(error.message) + } + + if (errorType === 'throttling') { + throw new ModelThrottledError(error.message, { cause: error }) + } + + throw error + } + } + + /** + * Gets API key from environment variables. + */ + private static _getEnvApiKey(): string | undefined { + return globalThis?.process?.env?.GEMINI_API_KEY + } + + /** + * Formats a request for the Google GenAI API. + */ + private _formatRequest(messages: Message[], options?: StreamOptions): GenerateContentParameters { + const contents = formatMessages(messages) + const config: GenerateContentConfig = {} + + // Add system instruction + if (options?.systemPrompt !== undefined) { + if (typeof options.systemPrompt === 'string') { + if (options.systemPrompt.trim().length > 0) { + config.systemInstruction = options.systemPrompt + } + } else if (Array.isArray(options.systemPrompt) && options.systemPrompt.length > 0) { + const textBlocks: string[] = [] + + for (const block of options.systemPrompt) { + if (block.type === 'textBlock') { + textBlocks.push(block.text) + } + } + + if (textBlocks.length > 0) { + config.systemInstruction = textBlocks.join('') + } + } + } + + // Add tool specifications + if (options?.toolSpecs && options.toolSpecs.length > 0) { + config.tools = [ + { + functionDeclarations: options.toolSpecs.map((spec) => ({ + name: spec.name, + description: spec.description, + parametersJsonSchema: spec.inputSchema, + })), + }, + ] + + if (options.toolChoice) { + if ('auto' in options.toolChoice) { + config.toolConfig = { functionCallingConfig: { mode: FunctionCallingConfigMode.AUTO } } + } else if ('any' in options.toolChoice) { + config.toolConfig = { functionCallingConfig: { mode: FunctionCallingConfigMode.ANY } } + } else if ('tool' in options.toolChoice) { + config.toolConfig = { + functionCallingConfig: { + mode: FunctionCallingConfigMode.ANY, + allowedFunctionNames: [options.toolChoice.tool.name], + }, + } + } + } + } + + // Append built-in tools (e.g., GoogleSearch, CodeExecution) + if (this._config.builtInTools && this._config.builtInTools.length > 0) { + if (!config.tools) { + config.tools = [] + } + config.tools.push(...this._config.builtInTools) + } + + // Spread params object for forward compatibility + if (this._config.params) { + Object.assign(config, this._config.params) + } + + return { + model: this._config.modelId ?? MODEL_DEFAULTS.gemini.modelId, + contents, + config, + } + } +} diff --git a/strands-ts/src/models/google/types.ts b/strands-ts/src/models/google/types.ts new file mode 100644 index 0000000000..686d704615 --- /dev/null +++ b/strands-ts/src/models/google/types.ts @@ -0,0 +1,89 @@ +/** + * Type definitions for the Google model provider. + */ + +import type { GoogleGenAI, GoogleGenAIOptions, Tool } from '@google/genai' +import type { BaseModelConfig } from '../model.js' + +/** + * Configuration interface for Google model provider. + * + * @example + * ```typescript + * const config: GoogleModelConfig = { + * modelId: 'gemini-2.5-flash', + * params: { temperature: 0.7, maxOutputTokens: 1024 } + * } + * ``` + * + * @see https://ai.google.dev/api/generate-content#generationconfig + */ +export interface GoogleModelConfig extends BaseModelConfig { + /** + * Gemini model identifier (e.g., gemini-2.5-flash, gemini-2.5-pro). + * + * @defaultValue 'gemini-2.5-flash' + * @see https://ai.google.dev/gemini-api/docs/models + */ + modelId?: string + + /** + * Additional parameters to pass to the Gemini API (e.g., temperature, maxOutputTokens). + * + * @see https://ai.google.dev/api/generate-content#generationconfig + */ + params?: Record + + /** + * Built-in tools (e.g., GoogleSearch, CodeExecution, UrlContext). + * These are appended as separate Tool objects alongside any functionDeclarations. + * + * @see https://ai.google.dev/gemini-api/docs/function-calling + */ + builtInTools?: Tool[] + + /** + * Whether to use the native Gemini countTokens API. + * + * When `true`, `countTokens()` calls the Gemini token counting API for + * accurate counts. When `false` or not set (default), skips the API call and uses + * the character-based heuristic estimator. + * + * @defaultValue false + */ + useNativeTokenCount?: boolean +} + +/** + * Options interface for creating a GoogleModel instance. + */ +export interface GoogleModelOptions extends GoogleModelConfig { + /** + * Gemini API key (falls back to GEMINI_API_KEY environment variable). + */ + apiKey?: string + + /** + * Pre-configured Google GenAI client instance. + * If provided, this client will be used instead of creating a new one. + */ + client?: GoogleGenAI + + /** + * Additional Google GenAI client configuration. + * Only used if client is not provided. + */ + clientConfig?: Omit +} + +/** + * Internal state for tracking streaming progress. + */ +export interface GoogleStreamState { + messageStarted: boolean + textContentBlockStarted: boolean + reasoningContentBlockStarted: boolean + hasToolCalls: boolean + inputTokens: number + outputTokens: number +} diff --git a/strands-ts/src/models/model.ts b/strands-ts/src/models/model.ts new file mode 100644 index 0000000000..892c46e97d --- /dev/null +++ b/strands-ts/src/models/model.ts @@ -0,0 +1,609 @@ +import { + type ContentBlock, + Message, + type MessageMetadata, + ReasoningBlock, + type Role, + type StopReason, + type SystemPrompt, + TextBlock, + ToolUseBlock, +} from '../types/messages.js' +import { CitationsBlock } from '../types/citations.js' +import type { Citation, CitationGeneratedContent } from '../types/citations.js' +import type { StateStore } from '../state-store.js' +import type { ToolChoice, ToolSpec } from '../tools/types.js' +import { + ModelContentBlockDeltaEvent, + ModelContentBlockStartEvent, + ModelContentBlockStopEvent, + ModelMessageStartEvent, + ModelMessageStopEvent, + ModelMetadataEvent, + ModelRedactionEvent, + type ModelStreamEvent, +} from './streaming.js' +import { MaxTokensError, ModelError, normalizeError } from '../errors.js' +import type { Redaction } from '../hooks/events.js' +import { logger } from '../logging/logger.js' +import { getContextWindowLimit } from './defaults.js' + +/** + * Resolves model metadata fields on a config object from built-in lookup tables + * when not explicitly set. Explicit values pass through unchanged. + * + * @internal + * @param config - The stored model config + * @param modelId - The model ID to look up + * @returns A new config with resolved metadata, or the original config if nothing to resolve + */ +export function resolveConfigMetadata(config: T, modelId: string): T { + if (config.contextWindowLimit !== undefined) return config + const limit = getContextWindowLimit(modelId) + if (limit === undefined) return config + return { ...config, contextWindowLimit: limit } +} + +class CitationAccumulator { + citations: Citation[] = [] + content: CitationGeneratedContent[] = [] + + push(citations: Citation[], content: CitationGeneratedContent[]): void { + this.citations.push(...citations) + this.content.push(...content) + } + + hasData(): boolean { + return this.citations.length > 0 + } + + reset(): void { + this.citations = [] + this.content = [] + } +} + +/** + * Configuration for prompt caching. + */ +export interface CacheConfig { + /** + * Caching strategy to use. + * - "auto": Automatically inject cache points at optimal positions based on model ID detection + * (after tools, after last user message) + * - "anthropic": Force enable Anthropic-style caching (useful for application inference profiles) + */ + strategy: 'auto' | 'anthropic' +} + +/** + * Base configuration interface for all model providers. + * + * This interface defines the common configuration properties that all + * model providers should support. Provider-specific configurations + * should extend this interface. + */ +export interface BaseModelConfig { + /** + * The model identifier. + * This typically specifies which model to use from the provider's catalog. + */ + modelId?: string + + /** + * Maximum number of tokens to generate in the response. + * + * @see Provider-specific documentation for exact behavior + */ + maxTokens?: number + + /** + * Controls randomness in generation. + * + * @see Provider-specific documentation for valid range + */ + temperature?: number + + /** + * Controls diversity via nucleus sampling. + * + * @see Provider-specific documentation for details + */ + topP?: number + + /** + * Maximum context window size in tokens for the model. + * + * This value represents the total token capacity shared between input and output. + * When not provided, it is automatically resolved from a built-in lookup table + * based on the configured model ID. An explicit value always takes precedence. + * + * When `modelId` is changed via `updateConfig()`, this value is automatically + * re-resolved if it was initially auto-populated. Explicitly set values are preserved. + */ + contextWindowLimit?: number +} + +/** + * Options interface for configuring streaming model invocation. + */ +export interface StreamOptions { + /** + * System prompt to guide the model's behavior. + * Can be a simple string or an array of content blocks for advanced caching. + */ + systemPrompt?: SystemPrompt + + /** + * Array of tool specifications that the model can use. + */ + toolSpecs?: ToolSpec[] + + /** + * Controls how the model selects tools to use. + */ + toolChoice?: ToolChoice + + /** + * Runtime state for model providers that manage server-side conversation state. + * The model can read and write this state during streaming (e.g., to store a + * response ID for conversation chaining). Mutations via `set`/`delete` are + * visible to the caller after the stream completes. + */ + modelState?: StateStore +} + +/** + * Options for counting tokens in a set of messages. + */ +export interface CountTokensOptions { + /** + * System prompt to guide the model's behavior. + * Can be a simple string or an array of content blocks for advanced caching. + */ + systemPrompt?: SystemPrompt + + /** + * Array of tool specifications to include in the count. + */ + toolSpecs?: ToolSpec[] +} + +/** + * Result interface for the streamAggregated method. + * Contains the complete message, stop reason, and optional metadata. + */ +export interface StreamAggregatedResult { + /** + * The complete message from the model. + */ + message: Message + + /** + * The reason why the model stopped generating. + */ + stopReason: StopReason + + /** + * Optional metadata about the model invocation, including usage statistics and metrics. + */ + metadata?: ModelMetadataEvent + + /** + * Optional redaction information when guardrails blocked input. + * Output redaction is handled by updating the message directly. + */ + redaction?: Redaction +} + +/** + * Base abstract class for model providers. + * Defines the contract that all model provider implementations must follow. + * + * Model providers handle communication with LLM APIs and implement streaming + * responses using async iterables. + * + * @typeParam T - Model configuration type extending BaseModelConfig + */ +export abstract class Model { + /** + * Updates the model configuration. + * Merges the provided configuration with existing settings. + * + * @param modelConfig - Configuration object with model-specific settings to update + */ + abstract updateConfig(modelConfig: T): void + + /** + * Retrieves the current model configuration. + * + * @returns The current configuration object + */ + abstract getConfig(): T + + /** + * The model ID from the current configuration, if configured. + */ + get modelId(): string | undefined { + return this.getConfig().modelId + } + + /** + * Whether this model manages conversation state server-side. + * + * When `true`, the server tracks conversation context across turns, so the SDK + * sends only the latest message instead of the full history. After each invocation, + * the agent's local message history is cleared automatically. + * + * Model providers that support server-side state management should override this + * to return `true`. + * + * @returns `false` by default + */ + get stateful(): boolean { + return false + } + + /** + * Streams a conversation with the model. + * Returns an async iterable that yields streaming events as they occur. + * + * @param messages - Array of conversation messages + * @param options - Optional streaming configuration + * @returns Async iterable of streaming events + */ + abstract stream(messages: Message[], options?: StreamOptions): AsyncIterable + + /** + * Count tokens for the given input before sending to the model. + * + * Used for proactive context management (e.g., triggering compression at a threshold). + * The base implementation uses a character-based heuristic (chars/4 for text, chars/2 for JSON). + * + * Subclasses should override this method to use native token counting APIs + * (e.g., Bedrock CountTokens, Anthropic countTokens, Gemini countTokens) + * for improved accuracy, falling back to `super.countTokens()` on API failure. + * + * @param messages - Array of conversation messages to count tokens for + * @param options - Optional options containing system prompt and tool specs + * @returns Total input token count + */ + async countTokens(messages: Message[], options?: CountTokensOptions): Promise { + return estimateTokensHeuristic(messages, options) + } + + /** + * Converts event data to event class representation + * + * @param event_data - Interface representation of event + * @returns Class representation of event + */ + private _convert_to_class_event(event_data: ModelStreamEvent): ModelStreamEvent { + switch (event_data.type) { + case 'modelMessageStartEvent': + return new ModelMessageStartEvent(event_data) + case 'modelContentBlockStartEvent': + return new ModelContentBlockStartEvent(event_data) + case 'modelContentBlockDeltaEvent': + return new ModelContentBlockDeltaEvent(event_data) + case 'modelContentBlockStopEvent': + return new ModelContentBlockStopEvent(event_data) + case 'modelMessageStopEvent': + return new ModelMessageStopEvent(event_data) + case 'modelMetadataEvent': + return new ModelMetadataEvent(event_data) + case 'modelRedactionEvent': + return new ModelRedactionEvent(event_data) + default: + throw new Error(`Unsupported event type: ${(event_data as { type: string }).type}`) + } + } + + /** + * Streams a conversation with aggregated content blocks and messages. + * Returns an async generator that yields streaming events and content blocks, and returns the final message with stop reason and optional metadata. + * + * This method enhances the basic stream() by collecting streaming events into complete + * ContentBlock and Message objects, which are needed by the agentic loop for tool execution + * and conversation management. + * + * The method yields: + * - ModelStreamEvent - Original streaming events (passed through) + * - ContentBlock - Complete content block (emitted when block completes) + * + * The method returns: + * - StreamAggregatedResult containing the complete message, stop reason, and optional metadata + * + * All exceptions thrown from this method are wrapped in ModelError to provide + * a consistent error type for model-related errors. Specific error subtypes like + * ContextWindowOverflowError, ModelThrottledError, and MaxTokensError are preserved. + * + * @param messages - Array of conversation messages + * @param options - Optional streaming configuration + * @returns Async generator yielding ModelStreamEvent | ContentBlock and returning a StreamAggregatedResult + * @throws ModelError - Base class for all model-related errors + * @throws ContextWindowOverflowError - When input exceeds the model's context window + * @throws ModelThrottledError - When the model provider throttles requests + * @throws MaxTokensError - When the model reaches its maximum token limit + */ + async *streamAggregated( + messages: Message[], + options?: StreamOptions + ): AsyncGenerator { + try { + // State maintained in closure + let messageRole: Role | null = null + const contentBlocks: ContentBlock[] = [] + let accumulatedText = '' + let accumulatedToolInput = '' + let toolName = '' + let toolUseId = '' + let toolReasoningSignature = '' + let accumulatedReasoning: { + text?: string + signature?: string + redactedContent?: Uint8Array + } = {} + const accumulatedCitations = new CitationAccumulator() + let stoppedMessage: Message | null = null + let finalStopReason: StopReason | null = null + let metadata: ModelMetadataEvent | undefined = undefined + let redactionMessage: string | undefined = undefined + + for await (const event_data of this.stream(messages, options)) { + const event = this._convert_to_class_event(event_data) + yield event // Pass through immediately + + // Aggregation logic based on event type + switch (event.type) { + case 'modelMessageStartEvent': + messageRole = event.role + contentBlocks.length = 0 // Reset + break + + case 'modelContentBlockStartEvent': + if (event.start?.type === 'toolUseStart') { + toolName = event.start.name + toolUseId = event.start.toolUseId + toolReasoningSignature = event.start.reasoningSignature ?? '' + } + accumulatedToolInput = '' + accumulatedText = '' + accumulatedReasoning = {} + accumulatedCitations.reset() + break + + case 'modelContentBlockDeltaEvent': { + switch (event.delta.type) { + case 'textDelta': + accumulatedText += event.delta.text + break + case 'toolUseInputDelta': + accumulatedToolInput += event.delta.input + break + case 'reasoningContentDelta': + if (event.delta.text) accumulatedReasoning.text = (accumulatedReasoning.text ?? '') + event.delta.text + if (event.delta.signature) accumulatedReasoning.signature = event.delta.signature + if (event.delta.redactedContent) accumulatedReasoning.redactedContent = event.delta.redactedContent + break + case 'citationsDelta': + accumulatedCitations.push(event.delta.citations, event.delta.content) + break + } + break + } + + case 'modelContentBlockStopEvent': { + // Finalize and emit complete ContentBlock + let block: ContentBlock + try { + if (toolUseId) { + block = new ToolUseBlock({ + name: toolName, + toolUseId: toolUseId, + input: accumulatedToolInput ? JSON.parse(accumulatedToolInput) : {}, + ...(toolReasoningSignature && { reasoningSignature: toolReasoningSignature }), + }) + toolUseId = '' // Reset + toolName = '' + toolReasoningSignature = '' + } else if (Object.keys(accumulatedReasoning).length > 0) { + block = new ReasoningBlock({ + ...accumulatedReasoning, + }) + accumulatedReasoning = {} // Reset after creating reasoning block + } else if (accumulatedCitations.hasData()) { + block = new CitationsBlock({ + citations: accumulatedCitations.citations, + content: accumulatedCitations.content, + }) + accumulatedCitations.reset() + } else { + block = new TextBlock(accumulatedText) + } + contentBlocks.push(block) + yield block + } catch (e: unknown) { + if (e instanceof SyntaxError) { + logger.error('unable to parse JSON string', e) + throw e + } + } + break + } + + case 'modelMessageStopEvent': + // Store message and stop reason + if (messageRole) { + stoppedMessage = new Message({ + role: messageRole, + content: [...contentBlocks], + }) + finalStopReason = event.stopReason! + } + break + + case 'modelMetadataEvent': + // Store metadata, keeping the last one if multiple events occur + metadata = event + break + + case 'modelRedactionEvent': + // Handle content redaction from guardrails + if (event.inputRedaction) { + // Store redaction message for agent to handle input message redaction + redactionMessage = event.inputRedaction.replaceContent + } + if (event.outputRedaction) { + // Update output message directly with redacted content + // Redaction event comes after modelMessageStopEvent, so we overwrite stoppedMessage + stoppedMessage = new Message({ + role: 'assistant', + content: [new TextBlock(event.outputRedaction.replaceContent)], + }) + } + break + + default: + break + } + } + + if (!stoppedMessage || !finalStopReason) { + // If we exit the loop without completing a message or stop reason, throw an error + throw new ModelError('Stream ended without completing a message') + } + + // Attach metadata after redaction so it applies to the final message. + const messageMetadata: MessageMetadata = { + ...(metadata?.usage !== undefined && { usage: metadata.usage }), + ...(metadata?.metrics !== undefined && { metrics: metadata.metrics }), + } + if (Object.keys(messageMetadata).length > 0) { + stoppedMessage = new Message({ + role: stoppedMessage.role, + content: stoppedMessage.content, + metadata: messageMetadata, + }) + } + + // Handle stop reason + if (finalStopReason === 'maxTokens') { + throw new MaxTokensError( + 'Model reached maximum token limit. This is an unrecoverable state that requires intervention.', + stoppedMessage + ) + } + + // Return the final message with stop reason and optional metadata + const result: StreamAggregatedResult = { + message: stoppedMessage, + stopReason: finalStopReason, + } + if (metadata !== undefined) { + result.metadata = metadata + } + if (redactionMessage !== undefined) { + result.redaction = { userMessage: redactionMessage } + } + return result + } catch (error) { + // Wrap non-ModelError errors in ModelError + if (error instanceof ModelError) { + throw error + } + const normalizedError = normalizeError(error) + throw new ModelError(normalizedError.message, { cause: error }) + } + } +} + +/** + * Estimate tokens for a content block using character-based heuristics. + * + * @param block - Content block to estimate tokens for + * @returns Estimated token count + */ +function estimateContentBlockTokens(block: ContentBlock): number { + let total = 0 + + switch (block.type) { + case 'textBlock': + total += heuristicText(block.text) + break + case 'toolUseBlock': + total += heuristicText(block.name) + total += heuristicJson(block.input) + break + case 'toolResultBlock': + for (const item of block.content) { + if (item.type === 'textBlock') { + total += heuristicText(item.text) + } else if (item.type === 'jsonBlock') { + total += heuristicJson(item.json) + } + } + break + case 'reasoningBlock': + if (block.text) total += heuristicText(block.text) + break + case 'guardContentBlock': + if (block.text) total += heuristicText(block.text.text) + break + case 'citationsBlock': + for (const item of block.content) { + if ('text' in item) total += heuristicText(item.text) + } + break + default: + break + } + + return total +} + +/** + * Estimate token count using character-based heuristics (text: chars/4, JSON: chars/2). + * Dependency-free fallback used by the base Model class. + */ +function estimateTokensHeuristic(messages: Message[], options?: CountTokensOptions): number { + let total = 0 + + if (options?.systemPrompt) { + if (typeof options.systemPrompt === 'string') { + total += heuristicText(options.systemPrompt) + } else { + for (const block of options.systemPrompt) { + if (block.type === 'textBlock') total += heuristicText(block.text) + else if (block.type === 'guardContentBlock' && block.text) total += heuristicText(block.text.text) + } + } + } + + for (const message of messages) { + for (const block of message.content) { + total += estimateContentBlockTokens(block) + } + } + + if (options?.toolSpecs) { + for (const spec of options.toolSpecs) { + total += heuristicJson(spec) + } + } + + return total +} + +function heuristicText(text: string): number { + return Math.ceil(text.length / 4) +} + +function heuristicJson(obj: unknown): number { + try { + return Math.ceil(JSON.stringify(obj).length / 2) + } catch { + logger.debug('unable to serialize object for token estimation, skipping') + return 0 + } +} diff --git a/strands-ts/src/models/openai/__tests__/chat.test.ts b/strands-ts/src/models/openai/__tests__/chat.test.ts new file mode 100644 index 0000000000..5c1325a108 --- /dev/null +++ b/strands-ts/src/models/openai/__tests__/chat.test.ts @@ -0,0 +1,1862 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import OpenAI from 'openai' +import { isNode } from '../../../__fixtures__/environment.js' +import { OpenAIModel } from '../index.js' +import { ContextWindowOverflowError, ModelThrottledError } from '../../../errors.js' +import { collectIterator } from '../../../__fixtures__/model-test-helpers.js' +import { Message, TextBlock, ToolUseBlock, ToolResultBlock, GuardContentBlock } from '../../../types/messages.js' +import type { SystemContentBlock } from '../../../types/messages.js' +import { ImageBlock, DocumentBlock, VideoBlock } from '../../../types/media.js' +import { warnOnce } from '../../../logging/warn-once.js' +import { logger } from '../../../logging/logger.js' + +/** + * Helper to create a mock OpenAI client with streaming support + */ +function createMockClient(streamGenerator: () => AsyncGenerator): OpenAI { + return { + chat: { + completions: { + create: vi.fn(async () => streamGenerator()), + }, + }, + } as any +} + +// Mock the OpenAI SDK +vi.mock('openai', () => { + const mockConstructor = vi.fn(function (this: any) { + return {} + }) + return { + default: mockConstructor, + } +}) + +vi.mock('../../../logging/warn-once.js', () => ({ + warnOnce: vi.fn(), +})) + +describe('OpenAIModel', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.restoreAllMocks() + // Set default env var for most tests using Vitest's stubEnv (Node.js only) + if (isNode) { + vi.stubEnv('OPENAI_API_KEY', 'sk-test-env') + } + }) + + afterEach(() => { + vi.clearAllMocks() + // Restore all environment variables to their original state (Node.js only) + if (isNode) { + vi.unstubAllEnvs() + } + }) + + // Shared helper to create a mock OpenAI client that captures the request + const createMockClientWithCapture = (captureContainer: { request: any }): any => { + return { + chat: { + completions: { + create: vi.fn(async (request: any) => { + captureContainer.request = request + return (async function* () { + yield { choices: [{ delta: { role: 'assistant' }, index: 0 }] } + yield { choices: [{ finish_reason: 'stop', delta: {}, index: 0 }] } + })() + }), + }, + }, + } as any + } + + describe('constructor', () => { + it('creates an instance with required modelId', () => { + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4', apiKey: 'sk-test' }) + const config = provider.getConfig() + expect(config.modelId).toBe('gpt-5.4') + }) + + it('uses custom model ID', () => { + const customModelId = 'gpt-3.5-turbo' + const provider = new OpenAIModel({ api: 'chat', modelId: customModelId, apiKey: 'sk-test' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: customModelId, + }) + }) + + it('warns when modelId is not explicitly set', () => { + new OpenAIModel({ api: 'chat', apiKey: 'sk-test' }) + expect(warnOnce).toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default modelId') + ) + }) + + it('does not warn when modelId is explicitly set', () => { + new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4', apiKey: 'sk-test' }) + expect(warnOnce).not.toHaveBeenCalledWith( + expect.objectContaining({ warn: expect.any(Function) }), + expect.stringContaining('using default modelId') + ) + }) + + it('uses API key from constructor parameter', () => { + const apiKey = 'sk-explicit' + new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4', apiKey }) + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: apiKey, + }) + ) + }) + + // Node.js-specific test: environment variable usage + if (isNode) { + it('uses API key from environment variable', () => { + vi.stubEnv('OPENAI_API_KEY', 'sk-from-env') + new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4' }) + // OpenAI client should be called without explicit apiKey (uses env var internally) + expect(OpenAI).toHaveBeenCalled() + }) + } + + it('explicit API key takes precedence over environment variable', () => { + if (isNode) { + vi.stubEnv('OPENAI_API_KEY', 'sk-from-env') + } + const explicitKey = 'sk-explicit' + new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4', apiKey: explicitKey }) + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: explicitKey, + }) + ) + }) + + it('throws error when no API key is available', () => { + if (isNode) { + vi.stubEnv('OPENAI_API_KEY', '') + } + expect(() => new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4' })).toThrow( + "OpenAI API key is required. Provide it via the 'apiKey' option (string or function) or set the OPENAI_API_KEY environment variable." + ) + }) + + it('uses custom client configuration', () => { + const timeout = 30000 + new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4', apiKey: 'sk-test', clientConfig: { timeout } }) + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + timeout: timeout, + }) + ) + }) + + it('uses provided client instance', () => { + vi.clearAllMocks() + const mockClient = {} as OpenAI + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + // Should not create a new OpenAI client + expect(OpenAI).not.toHaveBeenCalled() + expect(provider).toBeDefined() + }) + + it('provided client takes precedence over apiKey and clientConfig', () => { + vi.clearAllMocks() + const mockClient = {} as OpenAI + new OpenAIModel({ + api: 'chat', + modelId: 'gpt-5.4', + apiKey: 'sk-test', + client: mockClient, + clientConfig: { timeout: 30000 }, + }) + // Should not create a new OpenAI client when client is provided + expect(OpenAI).not.toHaveBeenCalled() + }) + + it('does not require API key when client is provided', () => { + vi.clearAllMocks() + if (isNode) { + vi.stubEnv('OPENAI_API_KEY', '') + } + const mockClient = {} as OpenAI + expect(() => new OpenAIModel({ api: 'chat', client: mockClient })).not.toThrow() + }) + + it('accepts function-based API key', () => { + const apiKeyFn = vi.fn(async () => 'sk-dynamic') + new OpenAIModel({ + api: 'chat', + modelId: 'gpt-5.4', + apiKey: apiKeyFn, + }) + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: apiKeyFn, + }) + ) + }) + + it('accepts async function-based API key', () => { + const apiKeyFn = async (): Promise => { + await new Promise((resolve) => globalThis.setTimeout(resolve, 10)) + return 'sk-async-key' + } + + new OpenAIModel({ + api: 'chat', + modelId: 'gpt-5.4', + apiKey: apiKeyFn, + }) + + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: apiKeyFn, + }) + ) + }) + }) + + describe('updateConfig', () => { + it('merges new config with existing config', () => { + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4', apiKey: 'sk-test', temperature: 0.5 }) + provider.updateConfig({ modelId: 'gpt-5.4', temperature: 0.8, maxTokens: 2048 }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gpt-5.4', + temperature: 0.8, + maxTokens: 2048, + contextWindowLimit: 1_050_000, + }) + }) + + it('preserves fields not included in the update', () => { + const provider = new OpenAIModel({ + api: 'chat', + apiKey: 'sk-test', + modelId: 'gpt-3.5-turbo', + temperature: 0.5, + maxTokens: 1024, + }) + provider.updateConfig({ modelId: 'gpt-3.5-turbo', temperature: 0.8 }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gpt-3.5-turbo', + temperature: 0.8, + maxTokens: 1024, + }) + }) + + it('re-resolves contextWindowLimit when modelId changes and it was auto-resolved', () => { + const provider = new OpenAIModel({ api: 'chat', apiKey: 'sk-test' }) + expect(provider.getConfig().contextWindowLimit).toBe(1_050_000) // gpt-5.4 default + + provider.updateConfig({ modelId: 'gpt-4o' }) + expect(provider.getConfig().contextWindowLimit).toBe(128_000) // gpt-4o value + }) + + it('preserves explicit contextWindowLimit when modelId changes', () => { + const provider = new OpenAIModel({ api: 'chat', apiKey: 'sk-test', contextWindowLimit: 50_000 }) + expect(provider.getConfig().contextWindowLimit).toBe(50_000) + + provider.updateConfig({ modelId: 'gpt-4o' }) + expect(provider.getConfig().contextWindowLimit).toBe(50_000) // preserved + }) + }) + + describe('getConfig', () => { + it('returns the current configuration', () => { + const provider = new OpenAIModel({ + api: 'chat', + modelId: 'gpt-5.4', + apiKey: 'sk-test', + maxTokens: 1024, + temperature: 0.7, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gpt-5.4', + maxTokens: 1024, + temperature: 0.7, + contextWindowLimit: 1_050_000, + }) + }) + + it('includes contextWindowLimit in config when provided', () => { + const provider = new OpenAIModel({ + api: 'chat', + modelId: 'gpt-4o', + apiKey: 'sk-test', + contextWindowLimit: 128_000, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gpt-4o', + contextWindowLimit: 128_000, + }) + }) + + it('auto-populates contextWindowLimit from model ID lookup', () => { + const provider = new OpenAIModel({ api: 'chat', modelId: 'gpt-4o', apiKey: 'sk-test' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gpt-4o', + contextWindowLimit: 128_000, + }) + }) + + it('auto-populates contextWindowLimit for default model ID', () => { + const provider = new OpenAIModel({ api: 'chat', apiKey: 'sk-test' }) + expect(provider.getConfig()).toStrictEqual({ + contextWindowLimit: 1_050_000, + }) + }) + + it('does not override explicit contextWindowLimit', () => { + const provider = new OpenAIModel({ + api: 'chat', + modelId: 'gpt-4o', + apiKey: 'sk-test', + contextWindowLimit: 50_000, + }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'gpt-4o', + contextWindowLimit: 50_000, + }) + }) + + it('leaves contextWindowLimit undefined for unknown model IDs', () => { + const provider = new OpenAIModel({ api: 'chat', modelId: 'unknown-model', apiKey: 'sk-test' }) + expect(provider.getConfig()).toStrictEqual({ + modelId: 'unknown-model', + }) + }) + }) + + describe('managed params warning', () => { + it('warns on construction when params contains provider-managed keys', () => { + const warnSpy = vi.spyOn(logger, 'warn') + new OpenAIModel({ + api: 'chat', + client: {} as OpenAI, + params: { model: 'bad', stream: false }, + }) + expect(warnSpy).toHaveBeenCalledTimes(2) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("'model'")) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("'stream'")) + warnSpy.mockRestore() + }) + + it('warns on updateConfig when params contains provider-managed keys', () => { + const model = new OpenAIModel({ api: 'chat', client: {} as OpenAI }) + const warnSpy = vi.spyOn(logger, 'warn') + model.updateConfig({ params: { stream_options: { include_usage: false } } }) + expect(warnSpy).toHaveBeenCalledTimes(1) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("'stream_options'")) + warnSpy.mockRestore() + }) + + it('does not warn when params contains only non-managed keys', () => { + const warnSpy = vi.spyOn(logger, 'warn') + new OpenAIModel({ api: 'chat', client: {} as OpenAI, params: { seed: 42 } }) + expect(warnSpy).not.toHaveBeenCalled() + warnSpy.mockRestore() + }) + + it('provider-managed fields in params are overridden and cannot take effect', async () => { + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const warnSpy = vi.spyOn(logger, 'warn') + const provider = new OpenAIModel({ + api: 'chat', + modelId: 'gpt-5.4', + client: mockClient, + params: { + model: 'attacker-model', + messages: [{ role: 'user', content: 'hijacked' }], + stream: false, + stream_options: { include_usage: false }, + }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + await collectIterator(provider.stream(messages)) + expect(captured.request.model).toBe('gpt-5.4') + expect(captured.request.stream).toBe(true) + expect(captured.request.stream_options).toEqual({ include_usage: true }) + expect(Array.isArray(captured.request.messages)).toBe(true) + expect(captured.request.messages[0]).toEqual({ role: 'user', content: [{ type: 'text', text: 'Hi' }] }) + warnSpy.mockRestore() + }) + }) + + describe('stream', () => { + describe('validation', () => { + it('throws error when messages array is empty', async () => { + const mockClient = createMockClient(async function* () {}) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + + await expect(async () => { + await collectIterator(provider.stream([])) + }).rejects.toThrow('At least one message is required') + }) + + it('validates system prompt is not empty', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant', content: 'Hello' }, index: 0 }], + } + yield { + choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], + } + }) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + // System prompt that's only whitespace should not be sent + const events = await collectIterator(provider.stream(messages, { systemPrompt: ' ' })) + + // Should still get valid events + expect(events.length).toBeGreaterThan(0) + expect(events[0]?.type).toBe('modelMessageStartEvent') + }) + + it('throws error for streaming with n > 1', async () => { + const mockClient = createMockClient(async function* () {}) + const provider = new OpenAIModel({ + api: 'chat', + modelId: 'gpt-5.4', + client: mockClient, + params: { n: 2 }, + }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow('Streaming with n > 1 is not supported') + }) + + it('throws error for tool spec without name or description', async () => { + const mockClient = createMockClient(async function* () {}) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages, { + toolSpecs: [{ name: '', description: 'test', inputSchema: {} }], + })) { + // Should not reach here + } + }).rejects.toThrow('Tool specification must have both name and description') + }) + + it('throws error for empty tool result content', async () => { + const mockClient = createMockClient(async function* () {}) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-123', + status: 'success', + content: [], + }), + ], + }), + ] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow('Tool result for toolUseId "tool-123" has empty content') + }) + + it('handles tool result with error status', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant', content: 'Ok' }, index: 0 }], + } + yield { + choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], + } + }) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Run tool')] }), + new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ + name: 'calculator', + toolUseId: 'tool-123', + input: { expr: 'invalid' }, + }), + ], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-123', + status: 'error', + content: [new TextBlock('Division by zero')], + }), + ], + }), + ] + + // Should not throw - error status is handled by prepending [ERROR] + const events = await collectIterator(provider.stream(messages)) + + // Verify we got a response + expect(events.length).toBeGreaterThan(0) + expect(events[0]?.type).toBe('modelMessageStartEvent') + }) + + it('throws error for circular reference in tool input', async () => { + const mockClient = createMockClient(async function* () {}) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + + const circular: any = { a: 1 } + circular.self = circular + + const messages = [ + new Message({ role: 'user', content: [new TextBlock('Hi')] }), + new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ + name: 'test', + toolUseId: 'tool-1', + input: circular, + }), + ], + }), + ] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow('Failed to serialize tool input') + }) + }) + + describe('basic streaming', () => { + it('yields correct event sequence for simple text response', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [{ delta: { content: 'Hello' }, index: 0 }], + } + yield { + choices: [{ delta: { content: ' world' }, index: 0 }], + } + yield { + choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + // Now includes complete content block lifecycle: start, deltas, stop + expect(events).toHaveLength(6) + expect(events[0]).toEqual({ type: 'modelMessageStartEvent', role: 'assistant' }) + expect(events[1]).toEqual({ + type: 'modelContentBlockStartEvent', + }) + expect(events[2]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: 'Hello' }, + }) + expect(events[3]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: ' world' }, + }) + expect(events[4]).toEqual({ + type: 'modelContentBlockStopEvent', + }) + expect(events[5]).toEqual({ type: 'modelMessageStopEvent', stopReason: 'endTurn' }) + }) + }) + + it('emits modelMetadataEvent with usage information', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], + } + yield { + choices: [], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + const metadataEvent = events.find((e) => e.type === 'modelMetadataEvent') + expect(metadataEvent).toBeDefined() + expect(metadataEvent).toEqual({ + type: 'modelMetadataEvent', + usage: { + inputTokens: 10, + outputTokens: 5, + totalTokens: 15, + }, + }) + }) + + it('handles usage with undefined properties', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], + } + yield { + choices: [], + usage: {}, // Empty usage object + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + const metadataEvent = events.find((e) => e.type === 'modelMetadataEvent') + expect(metadataEvent).toBeDefined() + expect(metadataEvent).toEqual({ + type: 'modelMetadataEvent', + usage: { + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + }, + }) + }) + + it('filters out empty string content deltas', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [{ delta: { content: '' }, index: 0 }], // Empty content + } + yield { + choices: [{ delta: { content: 'Hello' }, index: 0 }], + } + yield { + choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + // Should not emit event for empty content + const contentEvents = events.filter((e) => e.type === 'modelContentBlockDeltaEvent') + expect(contentEvents).toHaveLength(1) + expect((contentEvents[0] as any).delta.text).toBe('Hello') + }) + + it('prevents duplicate message start events', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [{ delta: { role: 'assistant', content: 'Hello' }, index: 0 }], // Duplicate role + } + yield { + choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + // Suppress console.warn for this test + vi.spyOn(console, 'warn').mockImplementation(() => {}) + + const events = await collectIterator(provider.stream(messages)) + + // Should only have one message start event + const startEvents = events.filter((e) => e.type === 'modelMessageStartEvent') + expect(startEvents).toHaveLength(1) + }) + }) + + describe('tool calling', () => { + it('handles tool use request with contentBlockStart and contentBlockStop events', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_123', + type: 'function', + function: { name: 'calculator', arguments: '' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: { + tool_calls: [{ index: 0, function: { arguments: '{"expr' } }], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: { + tool_calls: [{ index: 0, function: { arguments: '":"2+2"}' } }], + }, + index: 0, + }, + ], + } + yield { + choices: [{ finish_reason: 'tool_calls', delta: {}, index: 0 }], + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Calculate 2+2')] })] + + const events = await collectIterator(provider.stream(messages)) + + // Verify key events in sequence + expect(events[0]).toEqual({ type: 'modelMessageStartEvent', role: 'assistant' }) + expect(events[1]).toEqual({ + type: 'modelContentBlockStartEvent', + start: { + type: 'toolUseStart', + name: 'calculator', + toolUseId: 'call_123', + }, + }) + expect(events[2]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'toolUseInputDelta', + input: '{"expr', + }, + }) + expect(events[3]).toEqual({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'toolUseInputDelta', + input: '":"2+2"}', + }, + }) + expect(events[4]).toEqual({ + type: 'modelContentBlockStopEvent', + }) + expect(events[5]).toEqual({ type: 'modelMessageStopEvent', stopReason: 'toolUse' }) + }) + + it('handles multiple tool calls', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_1', + type: 'function', + function: { name: 'tool1', arguments: '{}' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 1, + id: 'call_2', + type: 'function', + function: { name: 'tool2', arguments: '{}' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [{ finish_reason: 'tool_calls', delta: {}, index: 0 }], + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + // Should emit stop events for both tool calls + const stopEvents = events.filter((e) => e.type === 'modelContentBlockStopEvent') + expect(stopEvents).toHaveLength(2) + expect(stopEvents[0]).toEqual({ type: 'modelContentBlockStopEvent' }) + expect(stopEvents[1]).toEqual({ type: 'modelContentBlockStopEvent' }) + }) + + it('skips tool calls with invalid index', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: undefined as any, // Invalid index + id: 'call_123', + type: 'function', + function: { name: 'tool', arguments: '{}' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { + choices: [{ finish_reason: 'stop', delta: {}, index: 0 }], + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + // Suppress console.warn for this test + vi.spyOn(console, 'warn').mockImplementation(() => {}) + + const events = await collectIterator(provider.stream(messages)) + + // Should not emit any tool-related events + const toolEvents = events.filter( + (e) => e.type === 'modelContentBlockStartEvent' || e.type === 'modelContentBlockDeltaEvent' + ) + expect(toolEvents).toHaveLength(0) + + // The important thing is that invalid tool calls don't crash the stream + // and are properly skipped + expect(events.length).toBeGreaterThan(0) // Still got message events + }) + + it('tool argument deltas can be reassembled into valid JSON', async () => { + const mockClient = createMockClient(async function* () { + yield { choices: [{ delta: { role: 'assistant' }, index: 0 }] } + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_123', + type: 'function', + function: { name: 'calculator', arguments: '' }, + }, + ], + }, + index: 0, + }, + ], + } + // Split JSON across multiple chunks in realistic ways + yield { choices: [{ delta: { tool_calls: [{ index: 0, function: { arguments: '{"' } }] }, index: 0 }] } + yield { choices: [{ delta: { tool_calls: [{ index: 0, function: { arguments: 'x":' } }] }, index: 0 }] } + yield { choices: [{ delta: { tool_calls: [{ index: 0, function: { arguments: '10,' } }] }, index: 0 }] } + yield { choices: [{ delta: { tool_calls: [{ index: 0, function: { arguments: '"y":' } }] }, index: 0 }] } + yield { choices: [{ delta: { tool_calls: [{ index: 0, function: { arguments: '20}' } }] }, index: 0 }] } + yield { choices: [{ finish_reason: 'tool_calls', delta: {}, index: 0 }] } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + // Extract and concatenate all tool input deltas + const inputDeltas = events + .filter((e) => e.type === 'modelContentBlockDeltaEvent' && (e as any).delta.type === 'toolUseInputDelta') + .map((e) => (e as any).delta.input) + + const reassembled = inputDeltas.join('') + + // Should be valid JSON + expect(() => JSON.parse(reassembled)).not.toThrow() + expect(JSON.parse(reassembled)).toEqual({ x: 10, y: 20 }) + }) + + it('handles messages with both text and tool calls', async () => { + const mockClient = createMockClient(async function* () { + yield { choices: [{ delta: { role: 'assistant' }, index: 0 }] } + // Text content first + yield { choices: [{ delta: { content: 'Let me calculate ' }, index: 0 }] } + yield { choices: [{ delta: { content: 'that for you.' }, index: 0 }] } + // Then tool call + yield { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_123', + type: 'function', + function: { name: 'calculator', arguments: '{"expr":"2+2"}' }, + }, + ], + }, + index: 0, + }, + ], + } + yield { choices: [{ finish_reason: 'tool_calls', delta: {}, index: 0 }] } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Calculate 2+2')] })] + + const events = await collectIterator(provider.stream(messages)) + + // Should have text deltas followed by tool events + expect(events[0]?.type).toBe('modelMessageStartEvent') + // Text content block start + expect(events[1]?.type).toBe('modelContentBlockStartEvent') + // Text deltas + expect(events[2]?.type).toBe('modelContentBlockDeltaEvent') + expect((events[2] as any).delta.type).toBe('textDelta') + expect((events[2] as any).delta.text).toBe('Let me calculate ') + // Tool events should follow + const toolStartEvent = events.find( + (e) => e.type === 'modelContentBlockStartEvent' && (e as any).start?.type === 'toolUseStart' + ) + expect(toolStartEvent).toBeDefined() + // Both text and tool blocks should have stop events + const stopEvents = events.filter((e) => e.type === 'modelContentBlockStopEvent') + expect(stopEvents.length).toBeGreaterThan(0) + }) + }) + + describe('stop reasons', () => { + it('maps OpenAI stop reasons to SDK stop reasons', async () => { + const stopReasons = [ + { openai: 'stop', sdk: 'endTurn' }, + { openai: 'tool_calls', sdk: 'toolUse' }, + { openai: 'length', sdk: 'maxTokens' }, + { openai: 'content_filter', sdk: 'contentFiltered' }, + ] + + for (const { openai, sdk } of stopReasons) { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [{ finish_reason: openai, delta: {}, index: 0 }], + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + const stopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(stopEvent).toBeDefined() + expect((stopEvent as any).stopReason).toBe(sdk) + } + }) + + it('handles unknown stop reasons with warning', async () => { + const mockClient = createMockClient(async function* () { + yield { + choices: [{ delta: { role: 'assistant' }, index: 0 }], + } + yield { + choices: [{ finish_reason: 'new_unknown_reason', delta: {}, index: 0 }], + } + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const events = await collectIterator(provider.stream(messages)) + + // Should convert unknown stop reason to camelCase + const stopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(stopEvent).toBeDefined() + expect((stopEvent as any).stopReason).toBe('newUnknownReason') + + // Note: Warning logging is verified manually/visually since console.warn spying + // has test isolation issues when running the full test suite + }) + }) + + describe('API request formatting', () => { + it('formats API request correctly with all options', async () => { + let capturedRequest: any = null + let callCount = 0 + + const mockClient = { + chat: { + completions: { + create: vi.fn(async (request: any) => { + capturedRequest = request + callCount++ + // Return an async generator + return (async function* (): AsyncGenerator { + yield { choices: [{ delta: { role: 'assistant' }, index: 0 }] } + yield { choices: [{ finish_reason: 'stop', delta: {}, index: 0 }] } + })() + }), + }, + }, + } as any + + const provider = new OpenAIModel({ + api: 'chat', + modelId: 'gpt-5.4', + client: mockClient, + temperature: 0.7, + maxTokens: 1000, + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + const toolSpecs = [ + { + name: 'calculator', + description: 'Calculate expressions', + inputSchema: { type: 'object' as const, properties: { expr: { type: 'string' as const } } }, + }, + ] + + await collectIterator( + provider.stream(messages, { + systemPrompt: 'You are a helpful assistant', + toolSpecs, + toolChoice: { auto: {} }, + }) + ) + + // Verify create was called with correct structure + expect(callCount).toBe(1) + expect(capturedRequest).toBeDefined() + expect(capturedRequest).toEqual({ + model: 'gpt-5.4', + stream: true, + stream_options: { include_usage: true }, + temperature: 0.7, + max_completion_tokens: 1000, + messages: [ + { role: 'system', content: 'You are a helpful assistant' }, + { role: 'user', content: [{ type: 'text', text: 'Hi' }] }, + ], + tools: [ + { + type: 'function', + function: { + name: 'calculator', + description: 'Calculate expressions', + parameters: { type: 'object', properties: { expr: { type: 'string' } } }, + }, + }, + ], + tool_choice: 'auto', + }) + }) + }) + + describe('systemPrompt handling', () => { + it('formats array system prompt with text blocks only', async () => { + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await collectIterator( + provider.stream(messages, { + systemPrompt: [ + { type: 'textBlock', text: 'You are a helpful assistant' }, + { type: 'textBlock', text: 'Additional context here' }, + ] as SystemContentBlock[], + }) + ) + + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([ + { role: 'system', content: 'You are a helpful assistantAdditional context here' }, + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ]) + }) + + it('formats array system prompt with cache points', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + collectIterator( + provider.stream(messages, { + systemPrompt: [ + { type: 'textBlock', text: 'You are a helpful assistant' }, + { type: 'textBlock', text: 'Large context document' }, + { type: 'cachePointBlock', cacheType: 'default' }, + ] as SystemContentBlock[], + }) + ) + + // Verify warning was logged + expect(warnSpy).toHaveBeenCalledWith( + 'cache points are not supported in openai system prompts, ignoring cache points' + ) + + // Verify system message contains only text (cache points ignored) + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([ + { role: 'system', content: 'You are a helpful assistantLarge context document' }, + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ]) + + warnSpy.mockRestore() + }) + + it('handles empty array system prompt', async () => { + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await collectIterator( + provider.stream(messages, { + systemPrompt: [], + }) + ) + + // Empty array should not add system message + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }]) + }) + + it('formats array system prompt with single text block', async () => { + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await collectIterator( + provider.stream(messages, { + systemPrompt: [{ type: 'textBlock', text: 'You are a helpful assistant' }] as SystemContentBlock[], + }) + ) + + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([ + { role: 'system', content: 'You are a helpful assistant' }, + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ]) + }) + + it('warns and filters guard content from system prompt', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await collectIterator( + provider.stream(messages, { + systemPrompt: [ + { type: 'textBlock', text: 'You are a helpful assistant' }, + { + type: 'guardContentBlock', + text: { + qualifiers: ['grounding_source'], + text: 'Guard content', + }, + }, + ] as SystemContentBlock[], + }) + ) + + // Verify warning was logged + expect(warnSpy).toHaveBeenCalledWith( + 'guard content is not supported in openai system prompts, removing guard content block' + ) + + // Verify guard content is filtered out + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([ + { role: 'system', content: 'You are a helpful assistant' }, + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ]) + + warnSpy.mockRestore() + }) + + it('preserves text blocks when filtering guard content', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await collectIterator( + provider.stream(messages, { + systemPrompt: [ + { type: 'textBlock', text: 'First text' }, + { + type: 'guardContentBlock', + text: { + qualifiers: ['query'], + text: 'Guard content', + }, + }, + { type: 'textBlock', text: 'Second text' }, + ] as SystemContentBlock[], + }) + ) + + // Verify warning was logged + expect(warnSpy).toHaveBeenCalledWith( + 'guard content is not supported in openai system prompts, removing guard content block' + ) + + // Verify both text blocks preserved, guard content removed + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([ + { role: 'system', content: 'First textSecond text' }, + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ]) + + warnSpy.mockRestore() + }) + + it('handles system prompt with only guard content', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + + await collectIterator( + provider.stream(messages, { + systemPrompt: [ + { + type: 'guardContentBlock', + text: { + qualifiers: ['guard_content'], + text: 'Only guard content', + }, + }, + ] as SystemContentBlock[], + }) + ) + + // Verify warning was logged + expect(warnSpy).toHaveBeenCalledWith( + 'guard content is not supported in openai system prompts, removing guard content block' + ) + + // Verify no system message added (only guard content) + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }]) + + warnSpy.mockRestore() + }) + }) + + describe('guard content in messages', () => { + it('warns and filters guard content from user messages', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new TextBlock('Verify this:'), + new GuardContentBlock({ + text: { + qualifiers: ['grounding_source'], + text: 'Guard content', + }, + }), + new TextBlock('Is it correct?'), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + // Verify warning was logged + expect(warnSpy).toHaveBeenCalledWith( + 'block_type= | unsupported content type in openai user message | skipping' + ) + + // Verify guard content filtered out + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([ + { + role: 'user', + content: [ + { type: 'text', text: 'Verify this:' }, + { type: 'text', text: 'Is it correct?' }, + ], + }, + ]) + + warnSpy.mockRestore() + }) + + it('warns and filters guard content with image from user messages', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const imageBytes = new Uint8Array([1, 2, 3, 4]) + const messages = [ + new Message({ + role: 'user', + content: [ + new TextBlock('Check this image:'), + new GuardContentBlock({ + image: { + format: 'jpeg', + source: { bytes: imageBytes }, + }, + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + // Verify warning was logged + expect(warnSpy).toHaveBeenCalledWith( + 'block_type= | unsupported content type in openai user message | skipping' + ) + + // Verify guard content filtered out + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([ + { role: 'user', content: [{ type: 'text', text: 'Check this image:' }] }, + ]) + + warnSpy.mockRestore() + }) + + it('handles message with only guard content', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new GuardContentBlock({ + text: { + qualifiers: ['guard_content'], + text: 'Only guard content', + }, + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + // Verify warning was logged + expect(warnSpy).toHaveBeenCalledWith( + 'block_type= | unsupported content type in openai user message | skipping' + ) + + // Verify no user message added (only guard content) + expect(captured.request).toBeDefined() + expect(captured.request!.messages).toEqual([]) + + warnSpy.mockRestore() + }) + }) + + describe('media blocks', () => { + it('formats image block in user message as image_url with base64', async () => { + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) + const messages = [ + new Message({ + role: 'user', + content: [ + new TextBlock('What is in this image?'), + new ImageBlock({ format: 'png', source: { bytes: imageBytes } }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const userMsg = captured.request.messages[0] + expect(userMsg.role).toBe('user') + expect(userMsg.content).toHaveLength(2) + expect(userMsg.content[0]).toEqual({ type: 'text', text: 'What is in this image?' }) + expect(userMsg.content[1]).toEqual({ + type: 'image_url', + image_url: { url: 'data:image/png;base64,SGVsbG8=' }, + }) + }) + + it('formats image block in user message with URL source', async () => { + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [new ImageBlock({ format: 'jpeg', source: { url: 'https://example.com/img.jpg' } })], + }), + ] + + await collectIterator(provider.stream(messages)) + + const userMsg = captured.request.messages[0] + expect(userMsg.content[0]).toEqual({ + type: 'image_url', + image_url: { url: 'https://example.com/img.jpg' }, + }) + }) + + it('formats document block with bytes source as file in user message', async () => { + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const docBytes = new Uint8Array([1, 2, 3]) + const messages = [ + new Message({ + role: 'user', + content: [new DocumentBlock({ name: 'report.pdf', format: 'pdf', source: { bytes: docBytes } })], + }), + ] + + await collectIterator(provider.stream(messages)) + + const userMsg = captured.request.messages[0] + expect(userMsg.content[0]).toEqual({ + type: 'file', + file: { file_data: 'data:application/pdf;base64,AQID', filename: 'report.pdf' }, + }) + }) + + it('splits image from tool result into separate user message', async () => { + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const imageBytes = new Uint8Array([72, 101, 108, 108, 111]) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new TextBlock('Screenshot captured'), + new ImageBlock({ format: 'png', source: { bytes: imageBytes } }), + ], + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + // Tool message with text only + const toolMsg = captured.request.messages[0] + expect(toolMsg.role).toBe('tool') + expect(toolMsg.tool_call_id).toBe('tool-1') + expect(toolMsg.content).toBe('Screenshot captured') + + // Separate user message with image + const userMsg = captured.request.messages[1] + expect(userMsg.role).toBe('user') + expect(userMsg.content[0]).toEqual({ + type: 'image_url', + image_url: { url: 'data:image/png;base64,SGVsbG8=' }, + }) + }) + + it('injects placeholder text when tool result contains only images', async () => { + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new ImageBlock({ format: 'png', source: { bytes: new Uint8Array([1]) } })], + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const toolMsg = captured.request.messages[0] + expect(toolMsg.content).toContain('Tool successfully returned an image') + }) + + it('skips document block in tool result with warning', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new TextBlock('result'), + new DocumentBlock({ name: 'doc.pdf', format: 'pdf', source: { bytes: new Uint8Array([1]) } }), + ], + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const toolMsg = captured.request.messages[0] + expect(toolMsg.content).toBe('result') + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + + it('skips video block in tool result with warning', async () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + const captured: { request: any } = { request: null } + const mockClient = createMockClientWithCapture(captured) + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [ + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [ + new TextBlock('result'), + new VideoBlock({ format: 'mp4', source: { bytes: new Uint8Array([1]) } }), + ], + }), + ], + }), + ] + + await collectIterator(provider.stream(messages)) + + const toolMsg = captured.request.messages[0] + expect(toolMsg.content).toBe('result') + expect(warnSpy).toHaveBeenCalled() + warnSpy.mockRestore() + }) + }) + + describe('error handling', () => { + it('throws ContextWindowOverflowError for structured error with code', async () => { + const mockClient = { + chat: { + completions: { + create: vi.fn(async () => { + const error: any = new Error('Context length exceeded') + error.code = 'context_length_exceeded' + throw error + }), + }, + }, + } as any + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow(ContextWindowOverflowError) + }) + + it.each([ + 'maximum context length exceeded', + 'context_length_exceeded', + 'too many tokens', + 'context length', + 'Input is too long for requested model', + 'input length and `max_tokens` exceed context limit', + 'too many total text bytes', + ])('throws ContextWindowOverflowError for error message pattern "%s"', async (message) => { + const mockClient = { + chat: { + completions: { + create: vi.fn(async () => { + throw new Error(message) + }), + }, + }, + } as any + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow(ContextWindowOverflowError) + }) + + it('throws ContextWindowOverflowError for APIError instance', async () => { + const mockClient = { + chat: { + completions: { + create: vi.fn(async () => { + // Simulate APIError from openai package + const error: any = new Error('Context length exceeded') + error.name = 'APIError' + error.status = 400 + error.code = 'context_length_exceeded' + // Make it behave like an APIError instance + Object.setPrototypeOf(error, Error.prototype) + throw error + }), + }, + }, + } as any + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow(ContextWindowOverflowError) + }) + + it('passes through other errors unchanged', async () => { + const mockClient = { + chat: { + completions: { + create: vi.fn(async () => { + throw new Error('Invalid API key') + }), + }, + }, + } as any + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow('Invalid API key') + }) + + it('handles stream interruption errors', async () => { + const mockClient = createMockClient(async function* () { + yield { choices: [{ delta: { role: 'assistant' }, index: 0 }] } + yield { choices: [{ delta: { content: 'Hello' }, index: 0 }] } + // Stream interruption + throw new Error('Network connection lost') + }) + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Stream will be interrupted + } + }).rejects.toThrow('Network connection lost') + }) + + it('throws ModelThrottledError for HTTP 429 status', async () => { + const originalError: Error & { status?: number } = new Error('Too many requests') + originalError.status = 429 + + const mockClient = { + chat: { + completions: { + create: vi.fn(async () => { + throw originalError + }), + }, + }, + } as unknown as OpenAI + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow(ModelThrottledError) + }) + + it('throws ModelThrottledError for rate_limit_exceeded error code', async () => { + const originalError: Error & { code?: string } = new Error('Rate limit reached') + originalError.code = 'rate_limit_exceeded' + + const mockClient = { + chat: { + completions: { + create: vi.fn(async () => { + throw originalError + }), + }, + }, + } as unknown as OpenAI + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow(ModelThrottledError) + }) + + it('throws ModelThrottledError for error message containing rate limit pattern', async () => { + const mockClient = { + chat: { + completions: { + create: vi.fn(async () => { + throw new Error('You have exceeded your rate limit') + }), + }, + }, + } as unknown as OpenAI + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow(ModelThrottledError) + }) + + it('throws ModelThrottledError for too many requests message', async () => { + const mockClient = { + chat: { + completions: { + create: vi.fn(async () => { + throw new Error('Too many requests, please slow down') + }), + }, + }, + } as unknown as OpenAI + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + await expect(async () => { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + }).rejects.toThrow(ModelThrottledError) + }) + + it('preserves original error as cause in ModelThrottledError', async () => { + const originalError: Error & { status?: number } = new Error('Request too large for gpt-5.4 on tokens per min') + originalError.status = 429 + + const mockClient = { + chat: { + completions: { + create: vi.fn(async () => { + throw originalError + }), + }, + }, + } as unknown as OpenAI + + const provider = new OpenAIModel({ api: 'chat', client: mockClient }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hi')] })] + + try { + for await (const _ of provider.stream(messages)) { + // Should not reach here + } + expect.fail('Should have thrown') + } catch (error) { + expect(error).toBeInstanceOf(ModelThrottledError) + const throttleError = error as ModelThrottledError + expect(throttleError.cause).toBe(originalError) + expect(throttleError.message).toBe('Request too large for gpt-5.4 on tokens per min') + } + }) + }) +}) diff --git a/strands-ts/src/models/openai/__tests__/mantle.test.ts b/strands-ts/src/models/openai/__tests__/mantle.test.ts new file mode 100644 index 0000000000..0daeccf955 --- /dev/null +++ b/strands-ts/src/models/openai/__tests__/mantle.test.ts @@ -0,0 +1,239 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import OpenAI from 'openai' +import { isNode } from '../../../__fixtures__/environment.js' +import { OpenAIModel } from '../index.js' + +vi.mock('openai', () => { + const mockConstructor = vi.fn(function (this: unknown) { + return {} + }) + return { + default: mockConstructor, + } +}) + +const getTokenProviderMock = vi.fn() +vi.mock('@aws/bedrock-token-generator', () => ({ + getTokenProvider: (...args: unknown[]) => getTokenProviderMock(...args), +})) + +const TEST_MODEL_ID = 'openai.gpt-oss-120b' +const TEST_TOKEN = 'bedrock-api-key-deadbeef&Version=1' + +function lastApiKeySetter(): () => Promise { + const calls = (OpenAI as unknown as { mock: { calls: unknown[][] } }).mock.calls + expect(calls.length).toBeGreaterThan(0) + const options = calls[calls.length - 1]![0] as { apiKey: () => Promise } + expect(typeof options.apiKey).toBe('function') + return options.apiKey +} + +describe('OpenAIModel bedrockMantleConfig', () => { + let provideTokenMock: ReturnType + + beforeEach(() => { + vi.clearAllMocks() + if (isNode) { + // Mantle pathway shouldn't look at OPENAI_API_KEY — guard against + // accidental env leakage by clearing it for the suite. + vi.stubEnv('OPENAI_API_KEY', '') + vi.stubEnv('AWS_REGION', '') + vi.stubEnv('AWS_DEFAULT_REGION', '') + } + provideTokenMock = vi.fn().mockResolvedValue(TEST_TOKEN) + getTokenProviderMock.mockReturnValue(provideTokenMock) + }) + + afterEach(() => { + vi.clearAllMocks() + if (isNode) { + vi.unstubAllEnvs() + } + }) + + describe('constructor wiring', () => { + it('sets baseURL and installs async apiKey setter that mints a bearer token', async () => { + new OpenAIModel({ + modelId: TEST_MODEL_ID, + bedrockMantleConfig: { region: 'us-east-1' }, + }) + + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: 'https://bedrock-mantle.us-east-1.api.aws/v1', + apiKey: expect.any(Function), + }) + ) + + const apiKey = await lastApiKeySetter()() + expect(apiKey).toBe(TEST_TOKEN) + expect(getTokenProviderMock).toHaveBeenCalledWith({ region: 'us-east-1' }) + }) + + it('forwards optional credentials and expiresInSeconds to getTokenProvider', async () => { + const credentials = vi.fn() + new OpenAIModel({ + modelId: TEST_MODEL_ID, + bedrockMantleConfig: { + region: 'us-west-2', + credentials, + expiresInSeconds: 900, + }, + }) + + await lastApiKeySetter()() + + expect(getTokenProviderMock).toHaveBeenCalledWith({ + region: 'us-west-2', + credentials, + expiresInSeconds: 900, + }) + }) + + it('mints a fresh token on every apiKey setter call', async () => { + new OpenAIModel({ + modelId: TEST_MODEL_ID, + bedrockMantleConfig: { region: 'us-east-1' }, + }) + + const apiKey = lastApiKeySetter() + await apiKey() + await apiKey() + await apiKey() + + // The token provider is created once and reused, but it is invoked per call. + expect(getTokenProviderMock).toHaveBeenCalledTimes(1) + expect(provideTokenMock).toHaveBeenCalledTimes(3) + }) + + it('merges with other clientConfig fields while overriding baseURL and apiKey', () => { + const http = vi.fn() + new OpenAIModel({ + modelId: TEST_MODEL_ID, + clientConfig: { + timeout: 42, + fetch: http, + defaultHeaders: { 'X-Trace-Id': 'abc' }, + }, + bedrockMantleConfig: { region: 'us-east-1' }, + }) + + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ + baseURL: 'https://bedrock-mantle.us-east-1.api.aws/v1', + apiKey: expect.any(Function), + timeout: 42, + fetch: http, + defaultHeaders: { 'X-Trace-Id': 'abc' }, + }) + ) + }) + + it('does not check OPENAI_API_KEY when bedrockMantleConfig is set', () => { + // env vars are cleared in beforeEach — this would normally throw, but the + // Mantle pathway has its own auth and must bypass the check. + expect( + () => new OpenAIModel({ modelId: TEST_MODEL_ID, bedrockMantleConfig: { region: 'us-east-1' } }) + ).not.toThrow() + }) + + it('works for api: "chat" as well as the default responses api', async () => { + new OpenAIModel({ + api: 'chat', + modelId: TEST_MODEL_ID, + bedrockMantleConfig: { region: 'us-east-1' }, + }) + const apiKey = await lastApiKeySetter()() + expect(apiKey).toBe(TEST_TOKEN) + }) + }) + + describe('validation', () => { + it('throws when bedrockMantleConfig is combined with a pre-built client', () => { + const client = {} as OpenAI + expect( + () => + new OpenAIModel({ + modelId: TEST_MODEL_ID, + client, + bedrockMantleConfig: { region: 'us-east-1' }, + }) + ).toThrow(/bedrockMantleConfig.*pre-built/) + }) + + it('throws when clientConfig.baseURL is set alongside bedrockMantleConfig', () => { + expect( + () => + new OpenAIModel({ + modelId: TEST_MODEL_ID, + clientConfig: { baseURL: 'https://example.invalid' }, + bedrockMantleConfig: { region: 'us-east-1' }, + }) + ).toThrow(/baseURL/) + }) + + it('throws when clientConfig.apiKey is set alongside bedrockMantleConfig', () => { + expect( + () => + new OpenAIModel({ + modelId: TEST_MODEL_ID, + clientConfig: { apiKey: 'sk-nope' }, + bedrockMantleConfig: { region: 'us-east-1' }, + }) + ).toThrow(/apiKey/) + }) + + it('throws when top-level apiKey is set alongside bedrockMantleConfig', () => { + expect( + () => + new OpenAIModel({ + modelId: TEST_MODEL_ID, + apiKey: 'sk-nope', + bedrockMantleConfig: { region: 'us-east-1' }, + }) + ).toThrow(/apiKey/) + }) + }) + + describe('region resolution', () => { + it('throws when no region is available from config or env', () => { + expect(() => new OpenAIModel({ modelId: TEST_MODEL_ID, bedrockMantleConfig: {} })).toThrow( + /could not resolve an AWS region/ + ) + }) + + if (isNode) { + it('falls back to AWS_REGION env var', async () => { + vi.stubEnv('AWS_REGION', 'eu-west-1') + new OpenAIModel({ modelId: TEST_MODEL_ID, bedrockMantleConfig: {} }) + await lastApiKeySetter()() + expect(OpenAI).toHaveBeenCalledWith( + expect.objectContaining({ baseURL: 'https://bedrock-mantle.eu-west-1.api.aws/v1' }) + ) + expect(getTokenProviderMock).toHaveBeenCalledWith({ region: 'eu-west-1' }) + }) + + it('falls back to AWS_DEFAULT_REGION when AWS_REGION is unset', async () => { + vi.stubEnv('AWS_DEFAULT_REGION', 'ap-southeast-2') + new OpenAIModel({ modelId: TEST_MODEL_ID, bedrockMantleConfig: {} }) + await lastApiKeySetter()() + expect(getTokenProviderMock).toHaveBeenCalledWith({ region: 'ap-southeast-2' }) + }) + + it('prefers explicit region over env vars', async () => { + vi.stubEnv('AWS_REGION', 'eu-west-1') + new OpenAIModel({ modelId: TEST_MODEL_ID, bedrockMantleConfig: { region: 'us-east-1' } }) + await lastApiKeySetter()() + expect(getTokenProviderMock).toHaveBeenCalledWith({ region: 'us-east-1' }) + }) + } + }) + + describe('token minting errors', () => { + it('wraps token provider failures with actionable context', async () => { + provideTokenMock.mockRejectedValueOnce(new Error('no credentials in chain')) + new OpenAIModel({ modelId: TEST_MODEL_ID, bedrockMantleConfig: { region: 'us-east-1' } }) + await expect(lastApiKeySetter()()).rejects.toThrow(/us-east-1/) + }) + }) +}) diff --git a/strands-ts/src/models/openai/__tests__/responses.test.ts b/strands-ts/src/models/openai/__tests__/responses.test.ts new file mode 100644 index 0000000000..bfceee935f --- /dev/null +++ b/strands-ts/src/models/openai/__tests__/responses.test.ts @@ -0,0 +1,758 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import OpenAI from 'openai' +import { isNode } from '../../../__fixtures__/environment.js' +import { OpenAIModel } from '../index.js' +import { ContextWindowOverflowError, ModelThrottledError } from '../../../errors.js' +import { collectIterator } from '../../../__fixtures__/model-test-helpers.js' +import { Message, TextBlock, ToolUseBlock, ToolResultBlock } from '../../../types/messages.js' +import { ImageBlock, DocumentBlock } from '../../../types/media.js' +import { StateStore } from '../../../state-store.js' +import { logger } from '../../../logging/logger.js' + +/** + * Build a mock OpenAI client whose `responses.create` returns the given async generator. + * The last request passed to `create` is captured on `capture.request`. + */ +function createMockClient(streamGenerator: () => AsyncGenerator, capture: { request?: any } = {}): OpenAI { + return { + responses: { + create: vi.fn(async (request: any) => { + capture.request = request + return streamGenerator() + }), + }, + } as any +} + +// Mock the OpenAI SDK +vi.mock('openai', () => { + const mockConstructor = vi.fn(function (this: any) { + return {} + }) + return { + default: mockConstructor, + } +}) + +describe("OpenAIModel (api: 'responses')", () => { + beforeEach(() => { + vi.clearAllMocks() + vi.restoreAllMocks() + if (isNode) { + vi.stubEnv('OPENAI_API_KEY', 'sk-test-env') + } + }) + + afterEach(() => { + vi.clearAllMocks() + if (isNode) { + vi.unstubAllEnvs() + } + }) + + describe('constructor', () => { + it('uses API key from constructor parameter', () => { + new OpenAIModel({ api: 'responses', modelId: 'gpt-4o', apiKey: 'sk-explicit' }) + expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: 'sk-explicit' })) + }) + + if (isNode) { + it('uses API key from environment variable', () => { + vi.stubEnv('OPENAI_API_KEY', 'sk-from-env') + new OpenAIModel({ api: 'responses', modelId: 'gpt-4o' }) + expect(OpenAI).toHaveBeenCalled() + }) + } + + it('throws error when no API key is available', () => { + if (isNode) { + vi.stubEnv('OPENAI_API_KEY', '') + } + expect(() => new OpenAIModel({ api: 'responses', modelId: 'gpt-4o' })).toThrow(/OpenAI API key is required/) + }) + + it('uses provided client instance and skips OpenAI constructor', () => { + vi.clearAllMocks() + const client = {} as OpenAI + const model = new OpenAIModel({ api: 'responses', client }) + expect(OpenAI).not.toHaveBeenCalled() + expect(model).toBeDefined() + }) + + it('does not require API key when client is provided', () => { + if (isNode) { + vi.stubEnv('OPENAI_API_KEY', '') + } + const client = {} as OpenAI + expect(() => new OpenAIModel({ api: 'responses', client })).not.toThrow() + }) + }) + + describe('stateful', () => { + it('defaults to false', () => { + const model = new OpenAIModel({ api: 'responses', client: {} as OpenAI }) + expect(model.stateful).toBe(false) + }) + + it('returns true when explicitly enabled', () => { + const model = new OpenAIModel({ api: 'responses', client: {} as OpenAI, stateful: true }) + expect(model.stateful).toBe(true) + }) + + it('is construction-only and cannot be changed via updateConfig', () => { + const model = new OpenAIModel({ api: 'responses', client: {} as OpenAI, stateful: false }) + const warnSpy = vi.spyOn(logger, 'warn') + expect(model.stateful).toBe(false) + model.updateConfig({ stateful: true }) + expect(model.stateful).toBe(false) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("'stateful' is construction-only")) + warnSpy.mockRestore() + }) + }) + + describe('updateConfig / getConfig', () => { + it('merges config without clobbering unspecified fields', () => { + const model = new OpenAIModel({ + api: 'responses', + client: {} as OpenAI, + modelId: 'gpt-4o', + temperature: 0.5, + maxTokens: 1024, + }) + model.updateConfig({ temperature: 0.9 }) + expect(model.getConfig()).toMatchObject({ + modelId: 'gpt-4o', + temperature: 0.9, + maxTokens: 1024, + }) + }) + }) + + describe('managed params warning', () => { + it('warns on construction when params contains provider-managed keys', () => { + const warnSpy = vi.spyOn(logger, 'warn') + new OpenAIModel({ api: 'responses', client: {} as OpenAI, params: { model: 'bad', store: false } }) + expect(warnSpy).toHaveBeenCalledTimes(2) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("'model'")) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("'store'")) + warnSpy.mockRestore() + }) + + it('warns on updateConfig when params contains provider-managed keys', () => { + const model = new OpenAIModel({ api: 'responses', client: {} as OpenAI }) + const warnSpy = vi.spyOn(logger, 'warn') + model.updateConfig({ params: { stream: true } }) + expect(warnSpy).toHaveBeenCalledTimes(1) + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining("'stream'")) + warnSpy.mockRestore() + }) + + it('does not warn when params contains only non-managed keys', () => { + const warnSpy = vi.spyOn(logger, 'warn') + new OpenAIModel({ api: 'responses', client: {} as OpenAI, params: { reasoning: { summary: 'auto' } } }) + expect(warnSpy).not.toHaveBeenCalled() + warnSpy.mockRestore() + }) + }) + + describe('request formatting', () => { + const mkUserMessage = () => new Message({ role: 'user', content: [new TextBlock('Hi')] }) + + async function runOnce( + modelOptions: Omit[0], { api?: 'responses' }>, 'api'> = {}, + messages = [mkUserMessage()], + streamOptions: Parameters[1] = undefined + ): Promise { + const capture: { request?: any } = {} + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'resp_123' } } + yield { type: 'response.completed', response: { usage: undefined } } + }, capture) + const model = new OpenAIModel({ api: 'responses', client, ...modelOptions }) + await collectIterator(model.stream(messages, streamOptions)) + return capture.request + } + + it('includes model, input, stream, and store=false by default', async () => { + const req = await runOnce() + expect(req.model).toBe('gpt-5.4') + expect(req.stream).toBe(true) + expect(req.store).toBe(false) + expect(Array.isArray(req.input)).toBe(true) + }) + + it('sets store=true when stateful is enabled', async () => { + const req = await runOnce({ stateful: true }) + expect(req.store).toBe(true) + }) + + it('chains previous_response_id when stateful and modelState has responseId', async () => { + const modelState = new StateStore({ responseId: 'resp_prev' }) + const req = await runOnce({ stateful: true }, [mkUserMessage()], { modelState }) + expect(req.previous_response_id).toBe('resp_prev') + }) + + it('omits previous_response_id when stateful is disabled, even with responseId in modelState', async () => { + const modelState = new StateStore({ responseId: 'resp_prev' }) + const req = await runOnce({}, [mkUserMessage()], { modelState }) + expect(req.previous_response_id).toBeUndefined() + }) + + it('maps systemPrompt string to instructions', async () => { + const req = await runOnce({}, [mkUserMessage()], { systemPrompt: 'Be helpful.' }) + expect(req.instructions).toBe('Be helpful.') + }) + + it('merges toolSpecs with built-in tools from params', async () => { + const req = await runOnce({ params: { tools: [{ type: 'web_search' }] } }, [mkUserMessage()], { + toolSpecs: [ + { + name: 'calc', + description: 'calculator', + inputSchema: { type: 'object', properties: {} }, + }, + ], + }) + expect(req.tools).toEqual([ + { type: 'web_search' }, + { + type: 'function', + name: 'calc', + description: 'calculator', + parameters: { type: 'object', properties: {} }, + strict: null, + }, + ]) + }) + + it('maps tool_choice variants', async () => { + const toolSpecs = [{ name: 'calc', description: 'd', inputSchema: {} }] + const autoReq = await runOnce({}, [mkUserMessage()], { toolSpecs, toolChoice: { auto: {} } }) + expect(autoReq.tool_choice).toBe('auto') + + const anyReq = await runOnce({}, [mkUserMessage()], { toolSpecs, toolChoice: { any: {} } }) + expect(anyReq.tool_choice).toBe('required') + + const toolReq = await runOnce({}, [mkUserMessage()], { + toolSpecs, + toolChoice: { tool: { name: 'calc' } }, + }) + expect(toolReq.tool_choice).toEqual({ type: 'function', name: 'calc' }) + }) + + it('formats temperature, maxTokens→max_output_tokens, and topP', async () => { + const req = await runOnce({ temperature: 0.3, maxTokens: 512, topP: 0.8 }) + expect(req.temperature).toBe(0.3) + expect(req.max_output_tokens).toBe(512) + expect(req.top_p).toBe(0.8) + }) + + it('passes through extra params fields to the request', async () => { + const req = await runOnce({ params: { reasoning: { summary: 'auto' } } }) + expect(req.reasoning).toEqual({ summary: 'auto' }) + }) + + it('provider-managed fields in params are overridden and cannot take effect', async () => { + const warnSpy = vi.spyOn(logger, 'warn') + const req = await runOnce({ + modelId: 'gpt-4o', + stateful: true, + params: { model: 'attacker-model', input: 'hijacked', stream: false, store: false }, + }) + expect(req.model).toBe('gpt-4o') + expect(req.stream).toBe(true) + expect(req.store).toBe(true) + expect(Array.isArray(req.input)).toBe(true) + warnSpy.mockRestore() + }) + + it('emits tool_use and tool_result as separate top-level items', async () => { + const messages = [ + new Message({ role: 'user', content: [new TextBlock('run it')] }), + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'calc', toolUseId: 'call_1', input: { expr: '2+2' } })], + }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'call_1', + status: 'success', + content: [new TextBlock('4')], + }), + ], + }), + ] + const req = await runOnce({}, messages) + const functionCall = req.input.find((i: any) => i.type === 'function_call') + const functionOutput = req.input.find((i: any) => i.type === 'function_call_output') + expect(functionCall).toMatchObject({ + type: 'function_call', + call_id: 'call_1', + name: 'calc', + arguments: JSON.stringify({ expr: '2+2' }), + }) + expect(functionOutput).toMatchObject({ + type: 'function_call_output', + call_id: 'call_1', + output: '4', + }) + }) + + it('prefixes errored tool results with [ERROR]', async () => { + const messages = [ + new Message({ role: 'user', content: [new TextBlock('x')] }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 't1', + status: 'error', + content: [new TextBlock('boom')], + }), + ], + }), + ] + const req = await runOnce({}, messages) + const out = req.input.find((i: any) => i.type === 'function_call_output') + expect(out.output).toBe('[ERROR] boom') + }) + + it('emits an array output with input_image when a tool result carries image bytes', async () => { + const imageBytes = new Uint8Array([1, 2, 3, 4]) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('fetch')] }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'img_tool', + status: 'success', + content: [ + new TextBlock('here is the image'), + new ImageBlock({ format: 'png', source: { bytes: imageBytes } }), + ], + }), + ], + }), + ] + const req = await runOnce({}, messages) + const out = req.input.find((i: any) => i.type === 'function_call_output') + expect(Array.isArray(out.output)).toBe(true) + expect(out.output).toEqual([ + { type: 'input_text', text: 'here is the image' }, + { type: 'input_image', image_url: expect.stringMatching(/^data:image\/png;base64,/) }, + ]) + }) + + it('emits an array output with input_file when a tool result carries a document', async () => { + const docBytes = new Uint8Array([5, 6, 7, 8]) + const messages = [ + new Message({ role: 'user', content: [new TextBlock('read')] }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'doc_tool', + status: 'success', + content: [new DocumentBlock({ name: 'report.pdf', format: 'pdf', source: { bytes: docBytes } })], + }), + ], + }), + ] + const req = await runOnce({}, messages) + const out = req.input.find((i: any) => i.type === 'function_call_output') + expect(Array.isArray(out.output)).toBe(true) + expect(out.output).toEqual([ + { + type: 'input_file', + file_data: expect.stringMatching(/^data:application\/pdf;base64,/), + filename: 'report.pdf', + }, + ]) + }) + + it('keeps tool result output as a plain string when only text is present', async () => { + const messages = [ + new Message({ role: 'user', content: [new TextBlock('ping')] }), + new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'text_tool', + status: 'success', + content: [new TextBlock('pong')], + }), + ], + }), + ] + const req = await runOnce({}, messages) + const out = req.input.find((i: any) => i.type === 'function_call_output') + expect(typeof out.output).toBe('string') + expect(out.output).toBe('pong') + }) + }) + + describe('stream event mapping', () => { + it('captures responseId on response.created when stateful', async () => { + const modelState = new StateStore() + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'resp_abc' } } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client, stateful: true }) + await collectIterator( + model.stream([new Message({ role: 'user', content: [new TextBlock('hi')] })], { modelState }) + ) + expect(modelState.get('responseId')).toBe('resp_abc') + }) + + it('does NOT capture responseId when stateful is disabled', async () => { + const modelState = new StateStore() + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'resp_abc' } } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client }) + await collectIterator( + model.stream([new Message({ role: 'user', content: [new TextBlock('hi')] })], { modelState }) + ) + expect(modelState.get('responseId')).toBeUndefined() + }) + + it('emits text deltas inside a content block', async () => { + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'r' } } + yield { type: 'response.output_text.delta', delta: 'Hello' } + yield { type: 'response.output_text.delta', delta: ' world' } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client }) + const events = await collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + const types = events.map((e: any) => e.type) + expect(types).toEqual([ + 'modelMessageStartEvent', + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockStopEvent', + 'modelMessageStopEvent', + ]) + const deltas = events.filter((e: any) => e.type === 'modelContentBlockDeltaEvent').map((e: any) => e.delta) + expect(deltas).toEqual([ + { type: 'textDelta', text: 'Hello' }, + { type: 'textDelta', text: ' world' }, + ]) + }) + + it('switches content blocks between reasoning and text', async () => { + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'r' } } + yield { type: 'response.reasoning_text.delta', delta: 'thinking...' } + yield { type: 'response.output_text.delta', delta: 'answer' } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client }) + const events = await collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + const types = events.map((e: any) => e.type) + expect(types).toEqual([ + 'modelMessageStartEvent', + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', // reasoning + 'modelContentBlockStopEvent', + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', // text + 'modelContentBlockStopEvent', + 'modelMessageStopEvent', + ]) + }) + + it('emits tool call triplet after stream close and sets stopReason=toolUse', async () => { + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'r' } } + yield { + type: 'response.output_item.added', + item: { type: 'function_call', id: 'item_1', call_id: 'call_1', name: 'calc' }, + } + yield { + type: 'response.function_call_arguments.delta', + item_id: 'item_1', + delta: '{"a":', + } + yield { + type: 'response.function_call_arguments.delta', + item_id: 'item_1', + delta: '1}', + } + yield { + type: 'response.function_call_arguments.done', + item_id: 'item_1', + arguments: '{"a":1}', + } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client }) + const events = await collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + const startEvent = events.find( + (e: any) => e.type === 'modelContentBlockStartEvent' && e.start?.type === 'toolUseStart' + ) as any + expect(startEvent?.start).toEqual({ + type: 'toolUseStart', + name: 'calc', + toolUseId: 'call_1', + }) + const deltaEvent = events.find( + (e: any) => e.type === 'modelContentBlockDeltaEvent' && e.delta?.type === 'toolUseInputDelta' + ) as any + expect(deltaEvent?.delta).toEqual({ type: 'toolUseInputDelta', input: '{"a":1}' }) + const stopEvent = events.find((e: any) => e.type === 'modelMessageStopEvent') as any + expect(stopEvent?.stopReason).toBe('toolUse') + }) + + it('maps response.incomplete with max_output_tokens to stopReason=maxTokens', async () => { + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'r' } } + yield { type: 'response.output_text.delta', delta: 'partial' } + yield { + type: 'response.incomplete', + response: { + incomplete_details: { reason: 'max_output_tokens' }, + usage: { input_tokens: 10, output_tokens: 5, total_tokens: 15 }, + }, + } + }) + const model = new OpenAIModel({ api: 'responses', client }) + const events = await collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + const stop = events.find((e: any) => e.type === 'modelMessageStopEvent') as any + expect(stop?.stopReason).toBe('maxTokens') + const metadata = events.find((e: any) => e.type === 'modelMetadataEvent') as any + expect(metadata?.usage).toEqual({ inputTokens: 10, outputTokens: 5, totalTokens: 15 }) + }) + + it('emits URL citation delta from response.output_text.annotation.added', async () => { + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'r' } } + yield { type: 'response.output_text.delta', delta: 'The answer is here.' } + yield { + type: 'response.output_text.annotation.added', + annotation: { + type: 'url_citation', + url: 'https://example.com', + title: 'Example', + cited_text: 'here', + }, + } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client }) + const events = await collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + const citation = events.find( + (e: any) => e.type === 'modelContentBlockDeltaEvent' && e.delta?.type === 'citationsDelta' + ) as any + expect(citation?.delta.citations[0]).toMatchObject({ + location: { type: 'web', url: 'https://example.com' }, + source: 'https://example.com', + title: 'Example', + }) + }) + + it('closes the text block before a citation, producing separate blocks when stream ends after citation', async () => { + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'r' } } + yield { type: 'response.output_text.delta', delta: 'Before citation' } + yield { + type: 'response.output_text.annotation.added', + annotation: { + type: 'url_citation', + url: 'https://example.com', + title: 'Source', + cited_text: 'cited', + }, + } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client }) + const events = await collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + const types = events.map((e: any) => e.type) + expect(types).toEqual([ + 'modelMessageStartEvent', + // Text block — closed before citation + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockStopEvent', + // Citation block + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockStopEvent', + 'modelMessageStopEvent', + ]) + const deltas = events.filter((e: any) => e.type === 'modelContentBlockDeltaEvent').map((e: any) => e.delta.type) + expect(deltas).toEqual(['textDelta', 'citationsDelta']) + }) + + it('closes the text block before a citation and opens a new text block after', async () => { + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'r' } } + yield { type: 'response.output_text.delta', delta: 'Before ' } + yield { + type: 'response.output_text.annotation.added', + annotation: { + type: 'url_citation', + url: 'https://example.com', + title: 'Source', + cited_text: 'cited', + }, + } + yield { type: 'response.output_text.delta', delta: ' after' } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client }) + const events = await collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + const types = events.map((e: any) => e.type) + expect(types).toEqual([ + 'modelMessageStartEvent', + // First text block + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockStopEvent', + // Citation block + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockStopEvent', + // New text block after citation + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockStopEvent', + 'modelMessageStopEvent', + ]) + const deltas = events.filter((e: any) => e.type === 'modelContentBlockDeltaEvent').map((e: any) => e.delta.type) + expect(deltas).toEqual(['textDelta', 'citationsDelta', 'textDelta']) + }) + + it('keeps consecutive citations in the same block without extra stop/start', async () => { + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'r' } } + yield { + type: 'response.output_text.annotation.added', + annotation: { type: 'url_citation', url: 'https://a.com', title: 'A', cited_text: 'a' }, + } + yield { + type: 'response.output_text.annotation.added', + annotation: { type: 'url_citation', url: 'https://b.com', title: 'B', cited_text: 'b' }, + } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client }) + const events = await collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + const types = events.map((e: any) => e.type) + expect(types).toEqual([ + 'modelMessageStartEvent', + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockStopEvent', + 'modelMessageStopEvent', + ]) + }) + + it('handles text → citation → text → citation → text with separate blocks each time', async () => { + const client = createMockClient(async function* () { + yield { type: 'response.created', response: { id: 'r' } } + yield { type: 'response.output_text.delta', delta: 'intro ' } + yield { + type: 'response.output_text.annotation.added', + annotation: { type: 'url_citation', url: 'https://1.com', title: '1', cited_text: 'c1' }, + } + yield { type: 'response.output_text.delta', delta: 'middle ' } + yield { + type: 'response.output_text.annotation.added', + annotation: { type: 'url_citation', url: 'https://2.com', title: '2', cited_text: 'c2' }, + } + yield { type: 'response.output_text.delta', delta: 'end' } + yield { type: 'response.completed', response: {} } + }) + const model = new OpenAIModel({ api: 'responses', client }) + const events = await collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + const deltaTypes = events + .filter((e: any) => e.type === 'modelContentBlockDeltaEvent') + .map((e: any) => e.delta.type) + expect(deltaTypes).toEqual(['textDelta', 'citationsDelta', 'textDelta', 'citationsDelta', 'textDelta']) + // 5 content blocks = 5 start + 5 stop events + const starts = events.filter((e: any) => e.type === 'modelContentBlockStartEvent') + const stops = events.filter((e: any) => e.type === 'modelContentBlockStopEvent') + expect(starts).toHaveLength(5) + expect(stops).toHaveLength(5) + }) + }) + + describe('error mapping', () => { + it('wraps 429 as ModelThrottledError', async () => { + const client: any = { + responses: { + create: vi.fn(async () => { + const err: any = new Error('Too many requests') + err.status = 429 + throw err + }), + }, + } + const model = new OpenAIModel({ api: 'responses', client }) + await expect( + collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + ).rejects.toBeInstanceOf(ModelThrottledError) + }) + + it('wraps context_length_exceeded as ContextWindowOverflowError', async () => { + const client: any = { + responses: { + create: vi.fn(async () => { + const err: any = new Error('This model has a maximum context length of 8k.') + err.code = 'context_length_exceeded' + throw err + }), + }, + } + const model = new OpenAIModel({ api: 'responses', client }) + await expect( + collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + ).rejects.toBeInstanceOf(ContextWindowOverflowError) + }) + + it.each([ + 'maximum context length exceeded', + 'context_length_exceeded', + 'too many tokens', + 'context length', + 'Input is too long for requested model', + 'input length and `max_tokens` exceed context limit', + 'too many total text bytes', + ])('wraps context overflow message pattern "%s" as ContextWindowOverflowError', async (message) => { + const client: any = { + responses: { + create: vi.fn(async () => { + throw new Error(message) + }), + }, + } + const model = new OpenAIModel({ api: 'responses', client }) + await expect( + collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + ).rejects.toBeInstanceOf(ContextWindowOverflowError) + }) + + it('rethrows unknown errors untouched', async () => { + const client: any = { + responses: { + create: vi.fn(async () => { + throw new Error('some other failure') + }), + }, + } + const model = new OpenAIModel({ api: 'responses', client }) + await expect( + collectIterator(model.stream([new Message({ role: 'user', content: [new TextBlock('x')] })])) + ).rejects.toThrow('some other failure') + }) + }) +}) diff --git a/strands-ts/src/models/openai/chat-adapter.ts b/strands-ts/src/models/openai/chat-adapter.ts new file mode 100644 index 0000000000..dbe24a82ab --- /dev/null +++ b/strands-ts/src/models/openai/chat-adapter.ts @@ -0,0 +1,462 @@ +/** + * Chat Completions API adapter for the OpenAI model provider. + * + * @internal + */ + +import OpenAI from 'openai' +import type { ChatCompletionContentPartText } from 'openai/resources/index.mjs' +import type { Message, StopReason, ToolResultBlock } from '../../types/messages.js' +import type { ImageBlock, DocumentBlock } from '../../types/media.js' +import { encodeBase64 } from '../../types/media.js' +import { toMimeType } from '../../mime.js' +import type { ModelStreamEvent } from '../streaming.js' +import type { StreamOptions } from '../model.js' +import { logger } from '../../logging/logger.js' +import { MODEL_DEFAULTS } from '../defaults.js' +import { formatImageDataUrl, warnManagedParams as warnManagedParamsShared } from './formatting.js' +import type { ChatStreamState, OpenAIChatConfig } from './types.js' + +export const DEFAULT_CHAT_MODEL_ID = MODEL_DEFAULTS.openai.modelId + +const MANAGED_PARAMS: ReadonlySet = new Set(['model', 'messages', 'stream', 'stream_options']) + +/** + * Logs a warning for each chat-managed key present in `params`. + * + * @internal + */ +export function warnManagedParams(params: Record | undefined): void { + warnManagedParamsShared(params, MANAGED_PARAMS) +} + +type OpenAIChatChoice = { + delta?: { + role?: string + content?: string + tool_calls?: Array<{ + index: number + id?: string + type?: string + function?: { + name?: string + arguments?: string + } + }> + } + finish_reason?: string + index: number +} + +/** + * Builds a Chat Completions streaming request body. + * + * @internal + */ +export function formatChatRequest( + config: OpenAIChatConfig, + messages: Message[], + options?: StreamOptions +): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming { + // User `params` are spread first so provider-managed fields always win. + // The managed-params warning fires at config time to surface the collision. + const request = { + ...(config.params ?? {}), + model: config.modelId ?? DEFAULT_CHAT_MODEL_ID, + messages: [] as OpenAI.Chat.Completions.ChatCompletionMessageParam[], + stream: true as const, + stream_options: { include_usage: true }, + } as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming + + if (options?.systemPrompt !== undefined) { + if (typeof options.systemPrompt === 'string') { + if (options.systemPrompt.trim().length > 0) { + request.messages.push({ role: 'system', content: options.systemPrompt }) + } + } else if (Array.isArray(options.systemPrompt) && options.systemPrompt.length > 0) { + const textBlocks: string[] = [] + let hasCachePoints = false + let hasGuardContent = false + + for (const block of options.systemPrompt) { + if (block.type === 'textBlock') { + textBlocks.push(block.text) + } else if (block.type === 'cachePointBlock') { + hasCachePoints = true + } else if (block.type === 'guardContentBlock') { + hasGuardContent = true + } + } + + if (hasCachePoints) { + logger.warn('cache points are not supported in openai system prompts, ignoring cache points') + } + if (hasGuardContent) { + logger.warn('guard content is not supported in openai system prompts, removing guard content block') + } + + if (textBlocks.length > 0) { + request.messages.push({ role: 'system', content: textBlocks.join('') }) + } + } + } + + request.messages.push(...formatChatMessages(messages)) + + if (config.temperature !== undefined) request.temperature = config.temperature + if (config.maxTokens !== undefined) request.max_completion_tokens = config.maxTokens + if (config.topP !== undefined) request.top_p = config.topP + if (config.frequencyPenalty !== undefined) request.frequency_penalty = config.frequencyPenalty + if (config.presencePenalty !== undefined) request.presence_penalty = config.presencePenalty + + if (options?.toolSpecs && options.toolSpecs.length > 0) { + request.tools = options.toolSpecs.map((spec) => { + if (!spec.name || !spec.description) { + throw new Error('Tool specification must have both name and description') + } + return { + type: 'function' as const, + function: { + name: spec.name, + description: spec.description, + parameters: spec.inputSchema as Record, + }, + } + }) + + if (options.toolChoice) { + if ('auto' in options.toolChoice) { + request.tool_choice = 'auto' + } else if ('any' in options.toolChoice) { + request.tool_choice = 'required' + } else if ('tool' in options.toolChoice) { + request.tool_choice = { + type: 'function', + function: { name: options.toolChoice.tool.name }, + } + } + } + } + + if ('n' in request && request.n !== undefined && request.n !== null && request.n > 1) { + throw new Error('Streaming with n > 1 is not supported') + } + + return request +} + +/** + * Converts SDK messages into Chat Completions message params. Tool result blocks + * are split out into separate `tool`-role messages; media inside tool results is + * hoisted into a following user-role message (OpenAI restricts media to user role). + */ +function formatChatMessages(messages: Message[]): OpenAI.Chat.Completions.ChatCompletionMessageParam[] { + const openAIMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [] + + for (const message of messages) { + if (message.role === 'user') { + const toolResults = message.content.filter((b) => b.type === 'toolResultBlock') + const otherContent = message.content.filter((b) => b.type !== 'toolResultBlock') + + if (otherContent.length > 0) { + const contentParts: OpenAI.Chat.Completions.ChatCompletionContentPart[] = [] + + for (const block of otherContent) { + switch (block.type) { + case 'textBlock': { + contentParts.push({ type: 'text', text: block.text }) + break + } + case 'imageBlock': { + const formatted = formatImageContentPart(block as ImageBlock) + if (formatted) contentParts.push(formatted) + break + } + case 'documentBlock': { + const docBlock = block as DocumentBlock + switch (docBlock.source.type) { + case 'documentSourceBytes': { + const mimeType = toMimeType(docBlock.format) || `application/${docBlock.format}` + const base64 = encodeBase64(docBlock.source.bytes) + contentParts.push({ + type: 'file', + file: { + file_data: `data:${mimeType};base64,${base64}`, + filename: docBlock.name, + }, + }) + break + } + case 'documentSourceText': { + logger.warn( + 'source_type= | openai does not support text document sources directly | converting to string content' + ) + contentParts.push({ type: 'text', text: docBlock.source.text }) + break + } + case 'documentSourceContentBlock': { + contentParts.push( + ...docBlock.source.content.map((b) => ({ + type: 'text', + text: b.text, + })) + ) + break + } + default: { + logger.warn( + `source_type=<${docBlock.source.type}> | openai only supports text content in user messages | skipping document block` + ) + break + } + } + break + } + default: { + logger.warn(`block_type=<${block.type}> | unsupported content type in openai user message | skipping`) + break + } + } + } + + if (contentParts.length > 0) { + openAIMessages.push({ role: 'user', content: contentParts }) + } + } + + const userMessagesWithMedia: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = [] + + for (const toolResult of toolResults) { + const [textContent, imageParts] = splitToolResultMedia(toolResult) + + if (imageParts.length > 0) { + logger.warn( + `tool_call_id=<${toolResult.toolUseId}> | moving images from tool result to separate user message for openai compatibility` + ) + } + + const effectiveTextContent = + textContent.trim().length === 0 && imageParts.length > 0 + ? 'Tool successfully returned an image. The image is being provided in the following user message.' + : textContent + + if (!effectiveTextContent || effectiveTextContent.trim().length === 0) { + throw new Error( + `Tool result for toolUseId "${toolResult.toolUseId}" has empty content. ` + + 'OpenAI requires tool messages to have non-empty content.' + ) + } + + const finalContent = toolResult.status === 'error' ? `[ERROR] ${effectiveTextContent}` : effectiveTextContent + + openAIMessages.push({ + role: 'tool', + tool_call_id: toolResult.toolUseId, + content: finalContent, + }) + + if (imageParts.length > 0) { + userMessagesWithMedia.push({ role: 'user', content: imageParts }) + } + } + + openAIMessages.push(...userMessagesWithMedia) + } else { + const toolUseCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = [] + const textParts: string[] = [] + + for (const block of message.content) { + switch (block.type) { + case 'textBlock': { + textParts.push(block.text) + break + } + case 'toolUseBlock': { + try { + toolUseCalls.push({ + id: block.toolUseId, + type: 'function', + function: { + name: block.name, + arguments: JSON.stringify(block.input), + }, + }) + } catch (error: unknown) { + if (error instanceof Error) { + throw new Error(`Failed to serialize tool input for "${block.name}`, error) + } + throw error + } + break + } + case 'reasoningBlock': { + if (block.text) { + logger.warn('block_type= | reasoning blocks not supported by openai | converting to text') + textParts.push(block.text) + } + break + } + default: { + logger.warn(`block_type=<${block.type}> | unsupported content type in openai assistant message | skipping`) + } + } + } + + const textContent = textParts.join('').trim() + const assistantMessage: OpenAI.Chat.Completions.ChatCompletionAssistantMessageParam = { + role: 'assistant', + content: textContent, + } + if (toolUseCalls.length > 0) { + assistantMessage.tool_calls = toolUseCalls + } + if (textContent.length > 0 || toolUseCalls.length > 0) { + openAIMessages.push(assistantMessage) + } + } + } + + return openAIMessages +} + +function formatImageContentPart( + imageBlock: ImageBlock +): OpenAI.Chat.Completions.ChatCompletionContentPartImage | undefined { + const url = formatImageDataUrl(imageBlock) + if (!url) return undefined + return { type: 'image_url', image_url: { url } } +} + +function splitToolResultMedia( + toolResult: ToolResultBlock +): [string, OpenAI.Chat.Completions.ChatCompletionContentPart[]] { + const textParts: string[] = [] + const imageParts: OpenAI.Chat.Completions.ChatCompletionContentPart[] = [] + + for (const c of toolResult.content) { + if (c.type === 'textBlock') { + textParts.push(c.text) + } else if (c.type === 'jsonBlock') { + try { + textParts.push(JSON.stringify(c.json)) + } catch (error: unknown) { + if (error instanceof Error) { + const dataPreview = + typeof c.json === 'object' && c.json !== null + ? `object with keys: ${Object.keys(c.json).slice(0, 5).join(', ')}` + : typeof c.json + textParts.push(`[JSON Serialization Error: ${error.message}. Data type: ${dataPreview}]`) + } + } + } else if (c.type === 'imageBlock') { + const formatted = formatImageContentPart(c as ImageBlock) + if (formatted) imageParts.push(formatted) + } else if (c.type === 'documentBlock') { + logger.warn('block_type= | documents not supported in openai tool results, skipping') + } else if (c.type === 'videoBlock') { + logger.warn('block_type= | videos not supported in openai tool results, skipping') + } + } + + return [textParts.join(''), imageParts] +} + +/** + * Maps a Chat Completions streaming chunk to one or more SDK events. Mutates + * `state` and `activeToolCalls` as a side effect. + * + * @internal + */ +export function mapChatChunkToEvents( + chunk: { choices: unknown[] }, + state: ChatStreamState, + activeToolCalls: Map +): ModelStreamEvent[] { + const events: ModelStreamEvent[] = [] + + if (!chunk.choices || chunk.choices.length === 0) return events + + const choice = chunk.choices[0] + if (!choice || typeof choice !== 'object') { + logger.warn(`choice=<${choice}> | invalid choice format in openai chunk`) + return events + } + + const typedChoice = choice as OpenAIChatChoice + if (!typedChoice.delta && !typedChoice.finish_reason) return events + + const delta = typedChoice.delta + + if (delta?.role && !state.messageStarted) { + state.messageStarted = true + events.push({ type: 'modelMessageStartEvent', role: delta.role as 'user' | 'assistant' }) + } + + if (delta?.content && delta.content.length > 0) { + if (!state.textContentBlockStarted) { + state.textContentBlockStarted = true + events.push({ type: 'modelContentBlockStartEvent' }) + } + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: delta.content }, + }) + } + + if (delta?.tool_calls && delta.tool_calls.length > 0) { + for (const toolCall of delta.tool_calls) { + if (toolCall.index === undefined || typeof toolCall.index !== 'number') { + logger.warn(`tool_call=<${JSON.stringify(toolCall)}> | received tool call with invalid index`) + continue + } + + if (toolCall.id && toolCall.function?.name) { + events.push({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: toolCall.function.name, toolUseId: toolCall.id }, + }) + activeToolCalls.set(toolCall.index, true) + } + + if (toolCall.function?.arguments) { + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: toolCall.function.arguments }, + }) + } + } + } + + if (typedChoice.finish_reason) { + if (state.textContentBlockStarted) { + events.push({ type: 'modelContentBlockStopEvent' }) + state.textContentBlockStarted = false + } + + for (const [index] of activeToolCalls) { + events.push({ type: 'modelContentBlockStopEvent' }) + activeToolCalls.delete(index) + } + + const stopReasonMap: Record = { + stop: 'endTurn', + tool_calls: 'toolUse', + length: 'maxTokens', + content_filter: 'contentFiltered', + } + const stopReason: StopReason = stopReasonMap[typedChoice.finish_reason] ?? snakeToCamel(typedChoice.finish_reason) + if (!stopReasonMap[typedChoice.finish_reason]) { + logger.warn( + `finish_reason=<${typedChoice.finish_reason}>, fallback=<${stopReason}> | unknown openai stop reason, using camelCase conversion as fallback` + ) + } + + events.push({ type: 'modelMessageStopEvent', stopReason }) + } + + return events +} + +function snakeToCamel(str: string): string { + return str.replace(/_([a-z])/g, (_, letter) => letter.toUpperCase()) +} diff --git a/strands-ts/src/models/openai/errors.ts b/strands-ts/src/models/openai/errors.ts new file mode 100644 index 0000000000..7bc0af5090 --- /dev/null +++ b/strands-ts/src/models/openai/errors.ts @@ -0,0 +1,52 @@ +/** + * Shared error classification for the OpenAI model provider. + * + * @internal + */ + +/** + * Error message patterns that indicate context window overflow. + * + * @see https://platform.openai.com/docs/guides/error-codes + */ +const CONTEXT_WINDOW_OVERFLOW_PATTERNS = [ + 'maximum context length', + 'context_length_exceeded', + 'too many tokens', + 'context length', + 'Input is too long for requested model', + 'input length and `max_tokens` exceed context limit', + 'too many total text bytes', +] + +/** + * Error patterns that indicate rate limiting. + * + * @see https://platform.openai.com/docs/guides/error-codes + */ +const RATE_LIMIT_PATTERNS = ['rate_limit_exceeded', 'rate limit', 'too many requests'] + +export type OpenAIErrorKind = 'contextOverflow' | 'throttling' + +/** + * Classifies an OpenAI SDK error. + * + * @internal + */ +export function classifyOpenAIError(err: Error & { status?: number; code?: string }): OpenAIErrorKind | undefined { + const message = err.message?.toLowerCase() ?? '' + const code = err.code?.toLowerCase() ?? '' + + if (err.status === 429 || code === 'rate_limit_exceeded' || RATE_LIMIT_PATTERNS.some((p) => message.includes(p))) { + return 'throttling' + } + + if ( + code === 'context_length_exceeded' || + CONTEXT_WINDOW_OVERFLOW_PATTERNS.some((pattern) => message.includes(pattern.toLowerCase())) + ) { + return 'contextOverflow' + } + + return undefined +} diff --git a/strands-ts/src/models/openai/formatting.ts b/strands-ts/src/models/openai/formatting.ts new file mode 100644 index 0000000000..f761a71b27 --- /dev/null +++ b/strands-ts/src/models/openai/formatting.ts @@ -0,0 +1,42 @@ +/** + * Shared media formatting helpers for OpenAI adapters. + * + * @internal + */ + +import type { ImageBlock } from '../../types/media.js' +import { encodeBase64 } from '../../types/media.js' +import { toMimeType } from '../../mime.js' +import { logger } from '../../logging/logger.js' + +/** + * Logs a warning for each key in `params` that is managed by the provider and + * would be overwritten at request time. Fires at config time so callers notice + * before sending a request. + */ +export function warnManagedParams(params: Record | undefined, managed: ReadonlySet): void { + if (!params) return + for (const key of Object.keys(params)) { + if (managed.has(key)) { + logger.warn( + `params_key=<${key}> | '${key}' is managed by the provider and will be ignored in params — use the dedicated config property instead` + ) + } + } +} + +/** + * Builds a `data:;base64,` URL for an image block. + * Returns `undefined` for unsupported source types. + */ +export function formatImageDataUrl(imageBlock: ImageBlock): string | undefined { + if (imageBlock.source.type === 'imageSourceBytes') { + const base64 = encodeBase64(imageBlock.source.bytes) + const mimeType = toMimeType(imageBlock.format) || `image/${imageBlock.format}` + return `data:${mimeType};base64,${base64}` + } + if (imageBlock.source.type === 'imageSourceUrl') { + return imageBlock.source.url + } + return undefined +} diff --git a/strands-ts/src/models/openai/index.ts b/strands-ts/src/models/openai/index.ts new file mode 100644 index 0000000000..a2c3b29406 --- /dev/null +++ b/strands-ts/src/models/openai/index.ts @@ -0,0 +1,26 @@ +/** + * OpenAI model provider. + * + * Defaults to the Responses API. Pass `api: 'chat'` to use Chat Completions. + * + * @example + * ```typescript + * import { OpenAIModel } from '@strands-agents/sdk/models/openai' + * + * // Responses API (default) + * const model = new OpenAIModel({ modelId: 'gpt-5.4', apiKey: 'sk-...' }) + * + * // Chat Completions + * const model = new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4', apiKey: 'sk-...' }) + * ``` + */ + +export { OpenAIModel } from './model.js' +export type { + OpenAIApi, + OpenAIChatConfig, + OpenAIModelConfig, + OpenAIModelOptions, + OpenAIResponsesConfig, +} from './types.js' +export type { BedrockMantleConfig } from './mantle.js' diff --git a/strands-ts/src/models/openai/mantle.ts b/strands-ts/src/models/openai/mantle.ts new file mode 100644 index 0000000000..8e86f1a1f0 --- /dev/null +++ b/strands-ts/src/models/openai/mantle.ts @@ -0,0 +1,142 @@ +/** + * Internal helpers for routing an {@link OpenAIModel} through Amazon Bedrock's + * OpenAI-compatible "Mantle" endpoint. + * + * Converts a {@link BedrockMantleConfig} into the `baseURL` and `apiKey` the + * OpenAI SDK consumes. Tokens are minted on demand via + * `@aws/bedrock-token-generator` so long-running agents survive the bearer + * token's maximum lifetime. + * + * `@aws/bedrock-token-generator` is declared as an optional peer dependency, so + * the import is lazy: it happens the first time the OpenAI client's async + * `apiKey` setter is invoked. + */ + +import type { AwsCredentialIdentity, AwsCredentialIdentityProvider } from '@smithy/types' + +const MANTLE_DOCS_URL = 'https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html' + +/** + * Async function that returns a freshly minted Bedrock Mantle bearer token. + * Matches the shape returned by `@aws/bedrock-token-generator`'s + * `getTokenProvider`. + * + * @internal + */ +export type TokenProvider = () => Promise + +/** + * Config for routing an OpenAI-compatible client through Amazon Bedrock's + * Mantle endpoint. + * + * When supplied to `OpenAIModel`, this config derives the OpenAI client's + * `baseURL` and `apiKey`. It cannot be combined with a pre-built `client`, + * a top-level `apiKey`, or `clientConfig.baseURL` / `clientConfig.apiKey`, + * since those are derived from this config. + */ +export interface BedrockMantleConfig { + /** + * AWS region hosting the Bedrock Mantle endpoint. If omitted, resolved from + * the `AWS_REGION` or `AWS_DEFAULT_REGION` environment variable. An error is + * thrown if none resolve. + */ + region?: string + + /** + * AWS credentials forwarded to the bearer token generator. Accepts either a + * static credential identity or a credential provider function (e.g. the + * result of `fromNodeProviderChain()` from `@aws-sdk/credential-providers`). + * When omitted, the token generator resolves credentials from the standard + * AWS credential chain. + */ + credentials?: AwsCredentialIdentity | AwsCredentialIdentityProvider + + /** + * Bearer token lifetime in seconds, forwarded to the token generator. + * Capped at 12 hours by AWS. When omitted, the generator's default applies. + * @see https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html + */ + expiresInSeconds?: number +} + +/** + * Resolves the AWS region for Mantle, preferring explicit config and falling + * back to the standard AWS env vars. + * + * @internal + */ +export function resolveMantleRegion(config: BedrockMantleConfig): string { + if (config.region) { + return config.region + } + + const envRegion = globalThis?.process?.env?.AWS_REGION || globalThis?.process?.env?.AWS_DEFAULT_REGION + if (envRegion) { + return envRegion + } + + throw new Error( + "could not resolve an AWS region for Bedrock Mantle. Pass 'region' in " + + 'bedrockMantleConfig or set AWS_REGION in the environment. ' + + `See ${MANTLE_DOCS_URL} for supported regions.` + ) +} + +/** + * Builds the Mantle base URL for a region. + * + * @internal + */ +export function bedrockMantleBaseUrl(region: string): string { + return `https://bedrock-mantle.${region}.api.aws/v1` +} + +/** + * Builds an async `apiKey` setter (matching the OpenAI SDK's `ApiKeySetter` + * signature) that mints a fresh bearer token on every request. + * + * The `@aws/bedrock-token-generator` package is loaded lazily on first use so + * applications that never touch the Mantle pathway don't need it installed. + * + * @internal + */ +export function createMantleApiKeySetter(config: BedrockMantleConfig, region: string): () => Promise { + let tokenProviderPromise: Promise | null = null + + const initProvider = async (): Promise => { + const { getTokenProvider } = await loadTokenGenerator() + return getTokenProvider({ + region, + ...(config.credentials !== undefined ? { credentials: config.credentials } : {}), + ...(config.expiresInSeconds !== undefined ? { expiresInSeconds: config.expiresInSeconds } : {}), + }) + } + + return async (): Promise => { + if (tokenProviderPromise === null) { + tokenProviderPromise = initProvider() + } + const provideToken = await tokenProviderPromise + try { + return await provideToken() + } catch (cause) { + throw new Error( + `failed to mint Bedrock Mantle bearer token for region '${region}' | ` + + 'verify your AWS credentials and network connectivity', + { cause } + ) + } + } +} + +async function loadTokenGenerator(): Promise { + try { + return await import('@aws/bedrock-token-generator') + } catch (cause) { + throw new Error( + "bedrockMantleConfig requires the '@aws/bedrock-token-generator' package | " + + "install it with: npm install '@aws/bedrock-token-generator'", + { cause } + ) + } +} diff --git a/strands-ts/src/models/openai/model.ts b/strands-ts/src/models/openai/model.ts new file mode 100644 index 0000000000..1ae5c52981 --- /dev/null +++ b/strands-ts/src/models/openai/model.ts @@ -0,0 +1,298 @@ +/** + * OpenAI model provider implementation. + * + * Supports both the Responses API (default) and the Chat Completions API. + * Selected via the `api` option at construction time. + * + * @see https://platform.openai.com/docs/api-reference/responses + * @see https://platform.openai.com/docs/api-reference/chat + */ + +import OpenAI from 'openai' +import type { ResponseStreamEvent } from 'openai/resources/responses/responses' +import { Model, resolveConfigMetadata } from '../model.js' +import type { StreamOptions } from '../model.js' +import type { Message } from '../../types/messages.js' +import type { ModelStreamEvent } from '../streaming.js' +import { ContextWindowOverflowError, ModelThrottledError } from '../../errors.js' +import { logger } from '../../logging/logger.js' +import { warnOnce } from '../../logging/warn-once.js' +import { MODEL_DEFAULTS, defaultModelWarningMessage } from '../defaults.js' +import { bedrockMantleBaseUrl, createMantleApiKeySetter, resolveMantleRegion } from './mantle.js' +import { classifyOpenAIError } from './errors.js' +import { formatChatRequest, mapChatChunkToEvents, warnManagedParams as warnChatManagedParams } from './chat-adapter.js' +import { + createResponsesStreamState, + finalizeResponsesStream, + formatResponsesRequest, + mapResponsesEventToSDK, + warnManagedParams as warnResponsesManagedParams, +} from './responses-adapter.js' +import type { + ChatStreamState, + OpenAIApi, + OpenAIChatConfig, + OpenAIModelConfig, + OpenAIModelOptions, + OpenAIResponsesConfig, +} from './types.js' + +/** + * OpenAI model provider. + * + * Defaults to the Responses API. Pass `api: 'chat'` to use Chat Completions. + * The `api` field is construction-only — it cannot be changed via + * {@link OpenAIModel.updateConfig}. + * + * @example + * ```typescript + * // Responses API (default) + * const model = new OpenAIModel({ modelId: 'gpt-5.4', apiKey: 'sk-...' }) + * ``` + * + * @example + * ```typescript + * // Chat Completions + * const model = new OpenAIModel({ api: 'chat', modelId: 'gpt-5.4', apiKey: 'sk-...' }) + * ``` + * + * @example + * ```typescript + * // Responses API with built-in web search + * const model = new OpenAIModel({ + * modelId: 'gpt-5.4', + * params: { tools: [{ type: 'web_search' }] }, + * }) + * ``` + */ +export class OpenAIModel extends Model { + private readonly _api: OpenAIApi + private _config: OpenAIModelConfig + private _client: OpenAI + + constructor(options: OpenAIModelOptions) { + super() + const { apiKey, client, clientConfig, bedrockMantleConfig, api = 'responses', ...modelConfig } = options + + if (api !== 'chat' && api !== 'responses') { + throw new Error(`Unsupported OpenAI API: '${api}'. Supported values: 'chat', 'responses'`) + } + + this._api = api + // `stateful` only exists on the responses branch of the discriminated union. + // Storing as the merged OpenAIModelConfig matches what `getConfig` returns. + this._config = modelConfig + + if (modelConfig.modelId === undefined) { + warnOnce(logger, defaultModelWarningMessage(MODEL_DEFAULTS.openai.modelId)) + } + + if (api === 'responses') { + warnResponsesManagedParams(modelConfig.params) + } else { + warnChatManagedParams(modelConfig.params) + } + + if (bedrockMantleConfig && client) { + throw new Error("'bedrockMantleConfig' cannot be combined with a pre-built 'client'.") + } + + if (client) { + this._client = client + } else if (bedrockMantleConfig) { + this._client = buildMantleClient(bedrockMantleConfig, apiKey, clientConfig) + } else { + const hasEnvKey = + typeof process !== 'undefined' && typeof process.env !== 'undefined' && process.env.OPENAI_API_KEY + if (!apiKey && !hasEnvKey) { + throw new Error( + "OpenAI API key is required. Provide it via the 'apiKey' option (string or function) or set the OPENAI_API_KEY environment variable." + ) + } + this._client = new OpenAI({ + ...(apiKey ? { apiKey } : {}), + ...clientConfig, + }) + } + } + + /** + * The OpenAI API mode this model operates in (`'chat'` or `'responses'`). + * Set at construction and immutable; exposed for debugging and serialization. + */ + get api(): OpenAIApi { + return this._api + } + + /** + * Whether this model manages conversation state server-side. + * + * `true` only for `api: 'responses'` with `stateful === true`. Chat Completions + * is always stateless, and Responses defaults to stateless. + */ + override get stateful(): boolean { + return this._api === 'responses' && this._config.stateful === true + } + + /** + * Updates the model configuration. + * + * `api` and `stateful` are construction-only — if present in `modelConfig`, + * they are stripped with a warning. Changing either at runtime would + * invalidate the invariants the agent builds on top of `stateful` (message + * history management, `previous_response_id` chaining). + */ + updateConfig(modelConfig: OpenAIModelConfig & { api?: OpenAIApi }): void { + const { api, stateful, ...rest } = modelConfig + if (api !== undefined) { + logger.warn(`api=<${api}> | 'api' is construction-only and cannot be changed via updateConfig — ignoring`) + } + if (stateful !== undefined) { + logger.warn( + `stateful=<${stateful}> | 'stateful' is construction-only and cannot be changed via updateConfig — ignoring` + ) + } + + if (this._api === 'responses') { + warnResponsesManagedParams(rest.params) + } else { + warnChatManagedParams(rest.params) + } + + this._config = { ...this._config, ...rest } + } + + getConfig(): OpenAIModelConfig { + return resolveConfigMetadata(this._config, this._config.modelId ?? MODEL_DEFAULTS.openai.modelId) + } + + async *stream(messages: Message[], options?: StreamOptions): AsyncIterable { + if (!messages || messages.length === 0) { + throw new Error('At least one message is required') + } + + if (this._api === 'chat') { + yield* this._streamChat(messages, options) + } else { + yield* this._streamResponses(messages, options) + } + } + + private async *_streamChat(messages: Message[], options?: StreamOptions): AsyncIterable { + try { + const request = formatChatRequest(this._config as OpenAIChatConfig, messages, options) + const stream = await this._client.chat.completions.create(request) + + const streamState: ChatStreamState = { + messageStarted: false, + textContentBlockStarted: false, + } + const activeToolCalls = new Map() + + let bufferedUsage: { + type: 'modelMetadataEvent' + usage: { inputTokens: number; outputTokens: number; totalTokens: number } + } | null = null + + for await (const chunk of stream) { + if (!chunk.choices || chunk.choices.length === 0) { + if (chunk.usage) { + bufferedUsage = { + type: 'modelMetadataEvent', + usage: { + inputTokens: chunk.usage.prompt_tokens ?? 0, + outputTokens: chunk.usage.completion_tokens ?? 0, + totalTokens: chunk.usage.total_tokens ?? 0, + }, + } + } + continue + } + + const events = mapChatChunkToEvents(chunk, streamState, activeToolCalls) + for (const event of events) { + if (event.type === 'modelMessageStopEvent' && bufferedUsage) { + yield bufferedUsage + bufferedUsage = null + } + yield event + } + } + + if (bufferedUsage) { + yield bufferedUsage + } + } catch (error) { + throw this._rewrapError(error) + } + } + + private async *_streamResponses(messages: Message[], options?: StreamOptions): AsyncIterable { + try { + const request = formatResponsesRequest(this._config as OpenAIResponsesConfig, messages, options, this.stateful) + const stream = await this._client.responses.create(request) + + const state = createResponsesStreamState() + + for await (const event of stream as AsyncIterable) { + for (const sdkEvent of mapResponsesEventToSDK(event, state, this.stateful, options?.modelState)) { + yield sdkEvent + } + } + + for (const sdkEvent of finalizeResponsesStream(state)) { + yield sdkEvent + } + } catch (error) { + throw this._rewrapError(error) + } + } + + private _rewrapError(error: unknown): unknown { + const err = error as Error & { status?: number; code?: string } + const kind = classifyOpenAIError(err) + + if (kind === 'throttling') { + const message = err.message ?? 'Request was throttled by the model provider' + logger.debug(`throttled | error_message=<${message}>`) + return new ModelThrottledError(message, { cause: err }) + } + + if (kind === 'contextOverflow') { + return new ContextWindowOverflowError(err.message) + } + + return error + } +} + +function buildMantleClient( + bedrockMantleConfig: NonNullable, + apiKey: OpenAIModelOptions['apiKey'], + clientConfig: OpenAIModelOptions['clientConfig'] +): OpenAI { + if (apiKey !== undefined) { + throw new Error( + "'apiKey' cannot be combined with 'bedrockMantleConfig'; the API key is derived from the Mantle config automatically." + ) + } + + const conflicting: string[] = [] + if (clientConfig?.apiKey !== undefined) conflicting.push('apiKey') + if (clientConfig?.baseURL !== undefined) conflicting.push('baseURL') + if (conflicting.length > 0) { + throw new Error( + `clientConfig must not contain ${conflicting.join(', ')} when bedrockMantleConfig is set; ` + + 'these are derived from the Mantle config automatically.' + ) + } + + // Resolve the region eagerly so missing-region configuration fails fast. + const region = resolveMantleRegion(bedrockMantleConfig) + + return new OpenAI({ + ...clientConfig, + baseURL: bedrockMantleBaseUrl(region), + apiKey: createMantleApiKeySetter(bedrockMantleConfig, region), + }) +} diff --git a/strands-ts/src/models/openai/responses-adapter.ts b/strands-ts/src/models/openai/responses-adapter.ts new file mode 100644 index 0000000000..7833a1ca45 --- /dev/null +++ b/strands-ts/src/models/openai/responses-adapter.ts @@ -0,0 +1,547 @@ +/** + * Responses API adapter for the OpenAI model provider. + * + * Built-in tool support status: + * | Tool | Support | + * |-------------------|----------------------------------------------------------| + * | web_search | Full: includes URL citations | + * | file_search | Partial: works but file citation annotations not emitted | + * | code_interpreter | Partial: works but executed code/stdout not surfaced | + * | mcp | Partial: works but approval flow not supported | + * | shell | Partial: container mode only | + * | image_generation | Not supported | + * + * @internal + */ + +import type { + ResponseStreamEvent, + ResponseInputItem, + ResponseFunctionToolCall, + ResponseFunctionCallOutputItem, + ResponseCreateParamsStreaming, +} from 'openai/resources/responses/responses' +import type { Message, StopReason, ToolResultBlock } from '../../types/messages.js' +import type { ImageBlock, DocumentBlock } from '../../types/media.js' +import { encodeBase64 } from '../../types/media.js' +import { toMimeType } from '../../mime.js' +import type { StateStore } from '../../state-store.js' +import type { ModelStreamEvent } from '../streaming.js' +import type { StreamOptions } from '../model.js' +import { logger } from '../../logging/logger.js' +import { MODEL_DEFAULTS } from '../defaults.js' +import { formatImageDataUrl, warnManagedParams as warnManagedParamsShared } from './formatting.js' +import type { OpenAIResponsesConfig } from './types.js' + +export const DEFAULT_RESPONSES_MODEL_ID = MODEL_DEFAULTS.openai.modelId + +const MANAGED_PARAMS: ReadonlySet = new Set(['model', 'input', 'stream', 'store']) + +/** + * Logs a warning for each responses-managed key present in `params`. + * + * @internal + */ +export function warnManagedParams(params: Record | undefined): void { + warnManagedParamsShared(params, MANAGED_PARAMS) +} + +/** + * Builds a Responses API streaming request body. + * + * @internal + */ +export function formatResponsesRequest( + config: OpenAIResponsesConfig, + messages: Message[], + options: StreamOptions | undefined, + stateful: boolean +): ResponseCreateParamsStreaming { + const input = formatResponsesMessages(messages) + + // User `params` are spread first so provider-managed fields (asserted + // required by `ResponseCreateParamsStreaming` below) always win. The + // managed-params warning fires at config time to surface the collision. + const request = { + ...(config.params ?? {}), + model: config.modelId ?? DEFAULT_RESPONSES_MODEL_ID, + input, + stream: true as const, + store: stateful, + } as ResponseCreateParamsStreaming + + if (stateful) { + const responseId = options?.modelState?.get('responseId') as string | undefined + if (responseId) { + request.previous_response_id = responseId + } + } + + if (options?.systemPrompt !== undefined) { + if (typeof options.systemPrompt === 'string') { + request.instructions = options.systemPrompt + } else if (Array.isArray(options.systemPrompt)) { + const texts: string[] = [] + for (const block of options.systemPrompt) { + if (block.type === 'textBlock') { + texts.push(block.text) + } + } + if (texts.length > 0) { + request.instructions = texts.join('') + } + } + } + + if (options?.toolSpecs && options.toolSpecs.length > 0) { + const existingTools = request.tools ?? [] + request.tools = [ + ...existingTools, + ...options.toolSpecs.map((spec) => ({ + type: 'function' as const, + name: spec.name, + description: spec.description ?? '', + parameters: (spec.inputSchema ?? {}) as Record, + // `null` defers to the OpenAI server default. The SDK's typed + // contract requires a value; omitting it (as the Python SDK does) + // is not an option here. + strict: null, + })), + ] + + if (options.toolChoice) { + if ('auto' in options.toolChoice) { + request.tool_choice = 'auto' + } else if ('any' in options.toolChoice) { + request.tool_choice = 'required' + } else if ('tool' in options.toolChoice) { + request.tool_choice = { type: 'function', name: options.toolChoice.tool.name } + } + } + } + + if (config.temperature !== undefined) request.temperature = config.temperature + if (config.maxTokens !== undefined) request.max_output_tokens = config.maxTokens + if (config.topP !== undefined) request.top_p = config.topP + + return request +} + +/** + * Formats SDK messages into Responses API input items. + * + * Per message, content blocks are split into three buckets: + * - Text/media → grouped in `{ role, content: [...] }` + * - Tool calls → separate `{ type: 'function_call', ... }` items + * - Tool results → separate `{ type: 'function_call_output', ... }` items + */ +function formatResponsesMessages(messages: Message[]): ResponseInputItem[] { + const input: ResponseInputItem[] = [] + + for (const message of messages) { + const role = message.role === 'assistant' ? 'assistant' : 'user' + const contentItems: Array> = [] + const toolCallItems: ResponseInputItem[] = [] + const toolResultItems: ResponseInputItem[] = [] + + for (const block of message.content) { + switch (block.type) { + case 'textBlock': { + if (role === 'user') { + contentItems.push({ type: 'input_text', text: block.text }) + } else { + contentItems.push({ type: 'output_text', text: block.text }) + } + break + } + + case 'imageBlock': { + const formatted = formatImageInput(block as ImageBlock) + if (formatted) contentItems.push(formatted) + break + } + + case 'documentBlock': { + const formatted = formatDocumentInput(block as DocumentBlock) + if (formatted) contentItems.push(formatted) + break + } + + case 'citationsBlock': { + const citBlock = block as { content: Array<{ text: string }> } + for (const c of citBlock.content) { + contentItems.push({ type: 'output_text', text: c.text }) + } + break + } + + case 'toolUseBlock': { + const toolBlock = block as { name: string; toolUseId: string; input: unknown } + const call: ResponseFunctionToolCall = { + type: 'function_call', + call_id: toolBlock.toolUseId, + name: toolBlock.name, + arguments: JSON.stringify(toolBlock.input), + } + toolCallItems.push(call) + break + } + + case 'toolResultBlock': { + const resultBlock = block as ToolResultBlock + const result: ResponseInputItem.FunctionCallOutput = { + type: 'function_call_output', + call_id: resultBlock.toolUseId, + output: formatToolResultOutput(resultBlock), + } + toolResultItems.push(result) + break + } + + case 'reasoningBlock': { + logger.warn( + 'block_type= | reasoning content is not yet supported in multi-turn conversations with the responses api' + ) + break + } + + default: { + logger.warn( + `block_type=<${block.type}> | unsupported content type in responses api message formatting | skipping` + ) + } + } + } + + // Cast is needed because assistant messages here use `output_text` content + // blocks, which the SDK's input types model as `ResponseOutputMessage` — + // a response-shaped type that requires `id`/`status`/`annotations`. The API + // accepts these fields as omitted on input, but the SDK types don't reflect that. + if (contentItems.length > 0) { + input.push({ + role, + content: contentItems, + } as unknown as ResponseInputItem) + } + + input.push(...toolCallItems) + input.push(...toolResultItems) + } + + return input +} + +/** + * Builds a Responses API `function_call_output.output` value from a SDK + * `toolResultBlock`. Returns a plain string for text-only results (joined with + * newlines) or the content-item array shape when the result carries image or + * document data. + */ +function formatToolResultOutput(resultBlock: ToolResultBlock): string | ResponseFunctionCallOutputItem[] { + const parts: ResponseFunctionCallOutputItem[] = [] + const texts: string[] = [] + let hasMedia = false + + for (const c of resultBlock.content) { + switch (c.type) { + case 'textBlock': + texts.push(c.text) + parts.push({ type: 'input_text', text: c.text }) + break + case 'jsonBlock': { + const jsonBlock = c as { json: unknown } + let text: string + try { + text = JSON.stringify(jsonBlock.json) + } catch { + text = '[JSON serialization error]' + } + texts.push(text) + parts.push({ type: 'input_text', text }) + break + } + case 'imageBlock': { + const url = formatImageDataUrl(c as ImageBlock) + if (url) { + hasMedia = true + parts.push({ type: 'input_image', image_url: url }) + } + break + } + case 'documentBlock': { + const docBlock = c as DocumentBlock + if (docBlock.source.type === 'documentSourceBytes') { + const base64 = encodeBase64(docBlock.source.bytes) + const mimeType = toMimeType(docBlock.format) || `application/${docBlock.format}` + hasMedia = true + parts.push({ + type: 'input_file', + file_data: `data:${mimeType};base64,${base64}`, + filename: docBlock.name, + }) + } else { + logger.warn( + `source_type=<${docBlock.source.type}> | only byte source documents supported in responses api tool results` + ) + } + break + } + default: + logger.warn(`block_type=<${c.type}> | unsupported tool result content type for responses api`) + } + } + + if (hasMedia) return parts + + // Text-only: collapse to a single string to match the API's simpler shape. + const text = texts.join('\n') + if (resultBlock.status === 'error') { + return `[ERROR] ${text}` + } + return text +} + +function formatImageInput(imageBlock: ImageBlock): Record | undefined { + const url = formatImageDataUrl(imageBlock) + if (!url) return undefined + return { type: 'input_image', image_url: url } +} + +function formatDocumentInput(docBlock: DocumentBlock): Record | undefined { + if (docBlock.source.type === 'documentSourceBytes') { + const base64 = encodeBase64(docBlock.source.bytes) + const mimeType = toMimeType(docBlock.format) || `application/${docBlock.format}` + return { + type: 'input_file', + file_data: `data:${mimeType};base64,${base64}`, + filename: docBlock.name, + } + } + logger.warn(`source_type=<${docBlock.source.type}> | only byte source documents supported in responses api`) + return undefined +} + +/** + * Internal stream state for the Responses adapter. Tracks the active content + * block kind so the adapter can emit stop/start events when content type + * switches (text ↔ reasoning ↔ citations). + * + * @internal + */ +export interface ResponsesStreamState { + dataType: string | null + toolCalls: Map + finalUsage: { inputTokens: number; outputTokens: number; totalTokens: number } | null + stopReason: StopReason +} + +/** + * Creates fresh stream state for a new Responses API stream. + * + * @internal + */ +export function createResponsesStreamState(): ResponsesStreamState { + return { + dataType: null, + toolCalls: new Map(), + finalUsage: null, + stopReason: 'endTurn', + } +} + +/** + * Maps a single Responses API stream event to zero or more SDK events. Mutates + * `state` and, when `stateful` is `true`, writes `responseId` into `modelState`. + * + * @internal + */ +export function mapResponsesEventToSDK( + event: ResponseStreamEvent, + state: ResponsesStreamState, + stateful: boolean, + modelState: StateStore | undefined +): ModelStreamEvent[] { + const events: ModelStreamEvent[] = [] + + switch (event.type) { + case 'response.created': { + if (stateful && modelState) { + modelState.set('responseId', event.response.id) + } + events.push({ type: 'modelMessageStartEvent', role: 'assistant' as const }) + break + } + + case 'response.output_text.delta': { + events.push(...switchContent('text', state.dataType)) + state.dataType = 'text' + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: event.delta }, + }) + break + } + + case 'response.reasoning_text.delta': + case 'response.reasoning_summary_text.delta': { + events.push(...switchContent('reasoning', state.dataType)) + state.dataType = 'reasoning' + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: event.delta }, + }) + break + } + + case 'response.output_text.annotation.added': { + // The SDK types `event.annotation` as `unknown` and doesn't export a + // named annotation union, so we narrow structurally on the fields we use. + const annotation = event.annotation as { type: string; url?: string; title?: string; cited_text?: string } + if (annotation.type === 'url_citation') { + // Close the in-flight text block before the citation delta. + // model.ts finalization picks ONE block kind per open block + // (citations wins over text), so text + citation in the same + // block drops the text on stop. Switching here forces a + // separate CitationsBlock, and the next text delta will open + // a fresh TextBlock. + events.push(...switchContent('citations', state.dataType)) + state.dataType = 'citations' + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'citationsDelta', + citations: [ + { + location: { + type: 'web' as const, + url: annotation.url ?? '', + }, + source: annotation.url ?? '', + sourceContent: [], + title: annotation.title ?? '', + }, + ], + content: [{ text: annotation.cited_text ?? '' }], + }, + }) + } else { + logger.warn(`annotation_type=<${annotation.type}> | unsupported annotation type in responses api`) + } + break + } + + case 'response.output_item.added': { + if (event.item.type === 'function_call') { + // `id` is optional in the SDK type but load-bearing here: it keys + // subsequent argument delta/done events. Skip rather than collapse + // to an empty string, which would let distinct calls share a key. + const { id: itemId, call_id: callId, name } = event.item + if (!itemId) { + logger.warn(`call_id=<${callId}> name=<${name}> | function_call event missing item id — skipping`) + break + } + state.toolCalls.set(itemId, { name, arguments: '', callId, itemId }) + } + break + } + + case 'response.function_call_arguments.delta': { + const tc = state.toolCalls.get(event.item_id) + if (tc) { + tc.arguments += event.delta + } + break + } + + case 'response.function_call_arguments.done': { + const tc = state.toolCalls.get(event.item_id) + if (tc) { + tc.arguments = event.arguments + } + break + } + + case 'response.incomplete': { + const resp = event.response + if (resp.usage) { + state.finalUsage = { + inputTokens: resp.usage.input_tokens, + outputTokens: resp.usage.output_tokens, + totalTokens: resp.usage.total_tokens, + } + } + if (resp.incomplete_details?.reason === 'max_output_tokens') { + state.stopReason = 'maxTokens' + } + break + } + + case 'response.completed': { + const resp = event.response + if (resp.usage) { + state.finalUsage = { + inputTokens: resp.usage.input_tokens, + outputTokens: resp.usage.output_tokens, + totalTokens: resp.usage.total_tokens, + } + } + break + } + + default: + break + } + + return events +} + +/** + * Emits the terminal events for a Responses API stream: closes any open content + * block, flushes accumulated tool calls, emits usage metadata, and finishes + * with `modelMessageStopEvent`. + * + * @internal + */ +export function finalizeResponsesStream(state: ResponsesStreamState): ModelStreamEvent[] { + const events: ModelStreamEvent[] = [] + + if (state.dataType !== null) { + events.push({ type: 'modelContentBlockStopEvent' }) + } + + for (const [, tc] of state.toolCalls) { + events.push({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: tc.name, toolUseId: tc.callId }, + }) + events.push({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: tc.arguments }, + }) + events.push({ type: 'modelContentBlockStopEvent' }) + } + + let stopReason = state.stopReason + if (state.toolCalls.size > 0) { + stopReason = 'toolUse' + } + + if (state.finalUsage) { + events.push({ type: 'modelMetadataEvent', usage: state.finalUsage }) + } + + events.push({ type: 'modelMessageStopEvent', stopReason }) + + return events +} + +function switchContent(newType: string, prevType: string | null): ModelStreamEvent[] { + const events: ModelStreamEvent[] = [] + if (newType !== prevType) { + if (prevType !== null) { + events.push({ type: 'modelContentBlockStopEvent' }) + } + events.push({ type: 'modelContentBlockStartEvent' }) + } + return events +} diff --git a/strands-ts/src/models/openai/types.ts b/strands-ts/src/models/openai/types.ts new file mode 100644 index 0000000000..97a6a21bc9 --- /dev/null +++ b/strands-ts/src/models/openai/types.ts @@ -0,0 +1,155 @@ +/** + * Type definitions for the OpenAI model provider. + */ + +import type OpenAI from 'openai' +import type { ApiKeySetter } from 'openai/client' +import type { ClientOptions } from 'openai' +import type { BaseModelConfig } from '../model.js' +import type { BedrockMantleConfig } from './mantle.js' + +/** + * Supported OpenAI API modes. + * - `'chat'`: Chat Completions API (stateless) + * - `'responses'`: Responses API (optional server-managed conversation state via `stateful: true`) + * + * @see https://platform.openai.com/docs/api-reference/chat + * @see https://platform.openai.com/docs/api-reference/responses + */ +export type OpenAIApi = 'chat' | 'responses' + +/** + * Fields shared by both Chat Completions and Responses API configurations. + */ +interface OpenAIBaseConfig extends BaseModelConfig { + /** + * OpenAI model identifier (e.g., `gpt-5.4`, `gpt-5.4-mini`, `gpt-4o`). + * Defaults depend on the selected `api`. + */ + modelId?: string + + /** + * Controls randomness in generation. + */ + temperature?: number + + /** + * Maximum number of tokens to generate in the response. + */ + maxTokens?: number + + /** + * Controls diversity via nucleus sampling. + */ + topP?: number + + /** + * Additional parameters passed through to the OpenAI API for forward compatibility. + * + * Provider-managed fields cannot be overridden via `params` — use the dedicated + * config properties instead. A warning is logged at config time if any are present: + * - Chat Completions: `model`, `messages`, `stream`, `stream_options` + * - Responses API: `model`, `input`, `stream`, `store` + */ + params?: Record +} + +/** + * Configuration fields specific to the Chat Completions API. + */ +export interface OpenAIChatConfig extends OpenAIBaseConfig { + /** + * Reduces repetition of token sequences (-2.0 to 2.0). + * Chat Completions only. + */ + frequencyPenalty?: number + + /** + * Encourages the model to talk about new topics (-2.0 to 2.0). + * Chat Completions only. + */ + presencePenalty?: number +} + +/** + * Configuration fields specific to the Responses API. + */ +export interface OpenAIResponsesConfig extends OpenAIBaseConfig { + /** + * When `true`, the server manages conversation state: the request sets + * `store: true` and chains turns via `previous_response_id`, the Agent + * clears its local message history after each invocation, and a + * `conversationManager` cannot be supplied. Defaults to `false` — the + * Responses API is used in stateless mode, where the full message history + * is sent on every turn. + */ + stateful?: boolean +} + +/** + * Runtime configuration shape returned by {@link OpenAIModel.getConfig}. + * + * Shared fields are required-shaped (still optional as per `BaseModelConfig`), and + * api-specific fields are optional because this is a merged view — callers cannot + * narrow on `api` from the returned config. + */ +export interface OpenAIModelConfig extends OpenAIBaseConfig { + frequencyPenalty?: number + presencePenalty?: number + stateful?: boolean +} + +interface OpenAIClientOptions { + /** + * OpenAI API key (falls back to `OPENAI_API_KEY` environment variable). + * + * Accepts either a static string or an async function that resolves to a string. + * When a function is provided, it is invoked before each request. + */ + apiKey?: string | ApiKeySetter + + /** + * Pre-configured OpenAI client instance. If provided, this client will be used + * instead of creating a new one. + */ + client?: OpenAI + + /** + * Additional OpenAI client configuration. Only used if `client` is not provided. + */ + clientConfig?: ClientOptions + + /** + * Route requests through Amazon Bedrock's OpenAI-compatible "Mantle" + * endpoint. When set, the OpenAI client's `baseURL` and `apiKey` are derived + * from this config; the top-level `apiKey`, `clientConfig.apiKey`, and + * `clientConfig.baseURL` options must not be passed alongside it. Cannot be + * combined with a pre-built `client`. Requires the optional peer dependency + * `@aws/bedrock-token-generator`. + * @see https://docs.aws.amazon.com/bedrock/latest/userguide/inference-openai.html + */ + bedrockMantleConfig?: BedrockMantleConfig +} + +/** + * Options for constructing an {@link OpenAIModel}. + * + * Discriminated on `api` so that selecting `'chat'` type-narrows to expose + * `frequencyPenalty` / `presencePenalty`, and selecting `'responses'` (or + * omitting `api`) narrows to expose `stateful`. + * + * `api` is construction-only: it cannot be changed via {@link OpenAIModel.updateConfig}. + */ +export type OpenAIModelOptions = + | ({ api?: 'responses' } & OpenAIResponsesConfig & OpenAIClientOptions) + | ({ api: 'chat' } & OpenAIChatConfig & OpenAIClientOptions) + +/** + * Internal stream state for the Chat Completions adapter. + * + * @internal + */ +export interface ChatStreamState { + messageStarted: boolean + textContentBlockStarted: boolean +} diff --git a/strands-ts/src/models/streaming.ts b/strands-ts/src/models/streaming.ts new file mode 100644 index 0000000000..893887d79f --- /dev/null +++ b/strands-ts/src/models/streaming.ts @@ -0,0 +1,567 @@ +import type { Role, StopReason } from '../types/messages.js' +import type { JSONValue } from '../types/json.js' +import type { Citation, CitationGeneratedContent } from '../types/citations.js' + +/** + * ModelStreamEvent types for Model interactions. + * + * This module follows a pattern where "Data" interfaces define the structure + * for objects, while corresponding classes extend those interfaces with additional + * functionality and type discrimination. + */ + +/** + * Union type representing all possible streaming events from a model provider. + * This is a discriminated union where each event has a unique type field. + * + * This allows for type-safe event handling using switch statements. + */ +export type ModelStreamEvent = + | ModelMessageStartEventData + | ModelContentBlockStartEventData + | ModelContentBlockDeltaEventData + | ModelContentBlockStopEventData + | ModelMessageStopEventData + | ModelMetadataEventData + | ModelRedactionEventData + +/** Set of all ModelStreamEvent type discriminators. */ +const modelStreamEventTypes: ReadonlySet = new Set([ + 'modelMessageStartEvent', + 'modelContentBlockStartEvent', + 'modelContentBlockDeltaEvent', + 'modelContentBlockStopEvent', + 'modelMessageStopEvent', + 'modelMetadataEvent', + 'modelRedactionEvent', +]) + +/** + * Type guard to check if an event with a type discriminator is a ModelStreamEvent. + * @param event - The event to check + * @returns true if the event is a ModelStreamEvent + */ +export function isModelStreamEvent(event: { type: string }): event is ModelStreamEvent { + return modelStreamEventTypes.has(event.type) +} + +/** + * Data for a message start event. + */ +export interface ModelMessageStartEventData { + /** + * Discriminator for message start events. + */ + type: 'modelMessageStartEvent' + + /** + * The role of the message being started. + */ + role: Role +} + +/** + * Event emitted when a new message starts in the stream. + */ +export class ModelMessageStartEvent implements ModelMessageStartEventData { + /** + * Discriminator for message start events. + */ + readonly type = 'modelMessageStartEvent' as const + + /** + * The role of the message being started. + */ + readonly role: Role + + constructor(data: ModelMessageStartEventData) { + this.role = data.role + } +} + +/** + * Data for a content block start event. + */ +export interface ModelContentBlockStartEventData { + /** + * Discriminator for content block start events. + */ + type: 'modelContentBlockStartEvent' + + /** + * Information about the content block being started. + * Only present for tool use blocks. + */ + start?: ContentBlockStart +} + +/** + * Event emitted when a new content block starts in the stream. + */ +export class ModelContentBlockStartEvent implements ModelContentBlockStartEventData { + /** + * Discriminator for content block start events. + */ + readonly type = 'modelContentBlockStartEvent' as const + + /** + * Information about the content block being started. + * Only present for tool use blocks. + */ + readonly start?: ContentBlockStart + + constructor(data: ModelContentBlockStartEventData) { + if (data.start !== undefined) { + this.start = data.start + } + } +} + +/** + * Data for a content block delta event. + */ +export interface ModelContentBlockDeltaEventData { + /** + * Discriminator for content block delta events. + */ + type: 'modelContentBlockDeltaEvent' + + /** + * The incremental content update. + */ + delta: ContentBlockDelta +} + +/** + * Event emitted when there is new content in a content block. + */ +export class ModelContentBlockDeltaEvent implements ModelContentBlockDeltaEventData { + /** + * Discriminator for content block delta events. + */ + readonly type = 'modelContentBlockDeltaEvent' as const + + /** + * Index of the content block being updated. + */ + readonly contentBlockIndex?: number + + /** + * The incremental content update. + */ + readonly delta: ContentBlockDelta + + constructor(data: ModelContentBlockDeltaEventData) { + this.delta = data.delta + } +} + +/** + * Data for a content block stop event. + */ +export interface ModelContentBlockStopEventData { + /** + * Discriminator for content block stop events. + */ + type: 'modelContentBlockStopEvent' +} + +/** + * Event emitted when a content block completes. + */ +export class ModelContentBlockStopEvent implements ModelContentBlockStopEventData { + /** + * Discriminator for content block stop events. + */ + readonly type = 'modelContentBlockStopEvent' as const + + constructor(_data: ModelContentBlockStopEventData) {} +} + +/** + * Data for a message stop event. + */ +export interface ModelMessageStopEventData { + /** + * Discriminator for message stop events. + */ + type: 'modelMessageStopEvent' + + /** + * Reason why generation stopped. + */ + stopReason: StopReason + + /** + * Additional provider-specific response fields. + */ + additionalModelResponseFields?: JSONValue +} + +/** + * Event emitted when the message completes. + */ +export class ModelMessageStopEvent implements ModelMessageStopEventData { + /** + * Discriminator for message stop events. + */ + readonly type = 'modelMessageStopEvent' as const + + /** + * Reason why generation stopped. + */ + readonly stopReason: StopReason + + /** + * Additional provider-specific response fields. + */ + readonly additionalModelResponseFields?: JSONValue + + constructor(data: ModelMessageStopEventData) { + this.stopReason = data.stopReason + if (data.additionalModelResponseFields !== undefined) { + this.additionalModelResponseFields = data.additionalModelResponseFields + } + } +} + +/** + * Data for a metadata event. + */ +export interface ModelMetadataEventData { + /** + * Discriminator for metadata events. + */ + type: 'modelMetadataEvent' + + /** + * Token usage information. + */ + usage?: Usage + + /** + * Performance metrics. + */ + metrics?: Metrics + + /** + * Trace information for observability. + */ + trace?: unknown +} + +/** + * Event containing metadata about the stream. + * Includes usage statistics, performance metrics, and trace information. + */ +export class ModelMetadataEvent implements ModelMetadataEventData { + /** + * Discriminator for metadata events. + */ + readonly type = 'modelMetadataEvent' as const + + /** + * Token usage information. + */ + readonly usage?: Usage + + /** + * Performance metrics. + */ + readonly metrics?: Metrics + + /** + * Trace information for observability. + */ + readonly trace?: unknown + + constructor(data: ModelMetadataEventData) { + if (data.usage !== undefined) { + this.usage = data.usage + } + if (data.metrics !== undefined) { + this.metrics = data.metrics + } + if (data.trace !== undefined) { + this.trace = data.trace + } + } +} + +/** + * Information about input content redaction. + * Does not include redactedContent since the original input is already available + * in the messages array from BeforeModelCallEvent. + */ +export interface RedactInputContent { + /** + * The content to replace the redacted input with. + */ + replaceContent: string +} + +/** + * Information about output content redaction. + * May include the original content if captured during streaming. + */ +export interface RedactOutputContent { + /** + * The original content that was blocked by guardrails. + * May not be available for all providers. + */ + redactedContent?: string + + /** + * The content to replace the redacted output with. + */ + replaceContent: string +} + +/** + * Data for a redact event. + * Emitted when guardrails block content and redaction is enabled. + */ +export interface ModelRedactionEventData { + /** + * Discriminator for redact events. + */ + type: 'modelRedactionEvent' + + /** + * Input redaction information (when input is blocked). + */ + inputRedaction?: RedactInputContent + + /** + * Output redaction information (when output is blocked). + */ + outputRedaction?: RedactOutputContent +} + +/** + * Event emitted when guardrails block content and trigger redaction. + */ +export class ModelRedactionEvent implements ModelRedactionEventData { + /** + * Discriminator for redact events. + */ + readonly type = 'modelRedactionEvent' as const + + /** + * Input redaction information (when input is blocked). + */ + readonly inputRedaction?: RedactInputContent + + /** + * Output redaction information (when output is blocked). + */ + readonly outputRedaction?: RedactOutputContent + + constructor(data: ModelRedactionEventData) { + if (data.inputRedaction !== undefined) { + this.inputRedaction = data.inputRedaction + } + if (data.outputRedaction !== undefined) { + this.outputRedaction = data.outputRedaction + } + } +} + +/** + * Information about a content block that is starting. + * Currently only represents tool use starts. + */ +export type ContentBlockStart = ToolUseStart + +/** + * Information about a tool use that is starting. + */ +export interface ToolUseStart { + /** + * Discriminator for tool use start. + */ + type: 'toolUseStart' + + /** + * The name of the tool being used. + */ + name: string + + /** + * Unique identifier for this tool use. + */ + toolUseId: string + + /** + * Reasoning signature from thinking models (e.g., Gemini). + * Must be preserved and sent back to the model for multi-turn tool use. + */ + reasoningSignature?: string +} + +/** + * A delta (incremental chunk) of content within a content block. + * Can be text, tool use input, or reasoning content. + * + * This is a discriminated union for type-safe delta handling. + */ +export type ContentBlockDelta = TextDelta | ToolUseInputDelta | ReasoningContentDelta | CitationsDelta + +/** + * Text delta within a content block. + * Represents incremental text content from the model. + */ +export interface TextDelta { + /** + * Discriminator for text delta. + */ + type: 'textDelta' + + /** + * Incremental text content. + */ + text: string +} + +/** + * Tool use input delta within a content block. + * Represents incremental tool input being generated. + */ +export interface ToolUseInputDelta { + /** + * Discriminator for tool use input delta. + */ + type: 'toolUseInputDelta' + + /** + * Partial JSON string representing the tool input. + */ + input: string +} + +/** + * Reasoning content delta within a content block. + * Represents incremental reasoning or thinking content. + */ +export interface ReasoningContentDelta { + /** + * Discriminator for reasoning delta. + */ + type: 'reasoningContentDelta' + + /** + * Incremental reasoning text. + */ + text?: string + + /** + * Incremental signature data. + */ + signature?: string + + /** + * Incremental redacted content data. + */ + redactedContent?: Uint8Array +} + +/** + * Citations content delta within a content block. + * Represents a citations content block from the model. + */ +export interface CitationsDelta { + /** + * Discriminator for citations content delta. + */ + type: 'citationsDelta' + + /** + * Array of citations linking generated content to source locations. + */ + citations: Citation[] + + /** + * The generated content associated with these citations. + */ + content: CitationGeneratedContent[] +} + +/** + * Token usage statistics for a model invocation. + * Tracks input, output, and total tokens, plus cache-related metrics. + */ +export interface Usage { + /** + * Number of tokens in the input (prompt). + */ + inputTokens: number + + /** + * Number of tokens in the output (completion). + */ + outputTokens: number + + /** + * Total number of tokens (input + output). + */ + totalTokens: number + + /** + * Number of input tokens read from cache. + * This can reduce latency and cost. + */ + cacheReadInputTokens?: number + + /** + * Number of input tokens written to cache. + * These tokens can be reused in future requests. + */ + cacheWriteInputTokens?: number +} + +/** + * Performance metrics for a model invocation. + */ +export interface Metrics { + /** + * Latency in milliseconds. + */ + latencyMs: number + + /** + * Time to first byte in milliseconds. + * Latency from sending the model request to receiving the first content chunk. + */ + timeToFirstByteMs?: number +} + +/** + * Accumulates token usage from a source into a target, mutating the target in place. + * + * @param target - Usage object to accumulate into + * @param source - Usage object to add from + */ +export function accumulateUsage(target: Usage, source: Usage): void { + target.inputTokens += source.inputTokens + target.outputTokens += source.outputTokens + target.totalTokens += source.totalTokens + if (source.cacheReadInputTokens !== undefined) { + target.cacheReadInputTokens = (target.cacheReadInputTokens ?? 0) + source.cacheReadInputTokens + } + if (source.cacheWriteInputTokens !== undefined) { + target.cacheWriteInputTokens = (target.cacheWriteInputTokens ?? 0) + source.cacheWriteInputTokens + } +} + +/** + * Creates a Usage object with all counters zeroed. + * + * @returns A Usage object with zeroed counters + */ +export function createEmptyUsage(): Usage { + return { + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + } +} diff --git a/strands-ts/src/models/vercel.ts b/strands-ts/src/models/vercel.ts new file mode 100644 index 0000000000..0a31b75533 --- /dev/null +++ b/strands-ts/src/models/vercel.ts @@ -0,0 +1,659 @@ +/** + * Vercel LanguageModelV3 model provider implementation. + * + * This module provides integration with any Vercel v3 compatible model provider, + * supporting streaming responses, tool use, and reasoning content. + * + * @see https://github.com/vercel/ai/tree/main/packages/provider/src/language-model/v3 + */ +import type { + LanguageModelV3, + LanguageModelV3CallOptions, + LanguageModelV3FilePart, + LanguageModelV3FinishReason, + LanguageModelV3FunctionTool, + LanguageModelV3Prompt, + LanguageModelV3ReasoningPart, + LanguageModelV3StreamPart, + LanguageModelV3TextPart, + LanguageModelV3ToolCallPart, + LanguageModelV3ToolChoice, + LanguageModelV3ToolResultOutput, + LanguageModelV3ToolResultPart, + LanguageModelV3Usage, +} from '@ai-sdk/provider' +import { APICallError } from '@ai-sdk/provider' +import type { SystemPrompt, StopReason } from '../types/messages.js' +import type { ToolChoice, ToolSpec } from '../tools/types.js' +import type { ModelStreamEvent, Usage } from './streaming.js' +import { Message, TextBlock, type ToolResultContent } from '../types/messages.js' +import { encodeBase64, ImageBlock, DocumentBlock, VideoBlock } from '../types/media.js' +import { Model, type BaseModelConfig, type StreamOptions } from './model.js' +import { + ModelContentBlockDeltaEvent, + ModelContentBlockStartEvent, + ModelContentBlockStopEvent, + ModelMessageStartEvent, + ModelMessageStopEvent, + ModelMetadataEvent, +} from './streaming.js' +import { ContextWindowOverflowError, ModelError, ModelThrottledError } from '../errors.js' +import { toMimeType } from '../mime.js' +import { logger } from '../logging/logger.js' + +/** + * Error message patterns that indicate context window overflow. + * These patterns are common across Vercel providers (Bedrock, OpenAI, Anthropic, etc.). + */ +const CONTEXT_WINDOW_OVERFLOW_PATTERNS = [ + 'too many tokens', + 'context length', + 'context_length_exceeded', + 'max_tokens exceeded', + 'too many total text bytes', + 'input is too long for requested model', + 'prompt is too long', + 'input too long', +] + +/** + * Call option fields from LanguageModelV3CallOptions that can be configured. + * Excludes prompt, tools, and toolChoice which are managed by the agent loop. + */ +type LanguageModelCallSettings = Omit + +/** + * Configuration for the VercelModel adapter. + * + * Extends BaseModelConfig with all LanguageModelV3 call settings (temperature, topP, topK, + * presencePenalty, frequencyPenalty, stopSequences, seed, etc.). When new fields are added + * to the Language Model Specification, they become available here automatically. + * + * Note: `maxTokens` (from BaseModelConfig) maps to `maxOutputTokens` in the underlying call. + * If both are set, `maxOutputTokens` takes precedence. + */ +export interface VercelModelConfig extends BaseModelConfig, LanguageModelCallSettings {} + +/** + * Options for creating a VercelModel instance. + */ +export interface VercelModelOptions extends Partial { + /** + * A LanguageModelV3 instance from any Vercel provider. + */ + provider: LanguageModelV3 +} + +/** + * Adapter that wraps a LanguageModelV3 instance + * for use as a Strands model provider. + * + * Implements the Model interface for any Vercel v3 compatible provider. + * Supports streaming responses, tool use, and reasoning content. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { VercelModel } from '@strands-agents/sdk/models/vercel' + * import { bedrock } from '@ai-sdk/amazon-bedrock' + * + * const agent = new Agent({ + * model: new VercelModel({ provider: bedrock('us.anthropic.claude-sonnet-4-20250514-v1:0') }), + * }) + * + * for await (const event of agent.stream('Hello!')) { + * if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + * process.stdout.write(event.delta.text) + * } + * } + * ``` + */ +export class VercelModel extends Model { + private _provider: LanguageModelV3 + private _config: VercelModelConfig + + /** + * Creates a new VercelModel instance. + * + * @param options - The model and optional configuration + */ + constructor(options: VercelModelOptions) { + super() + const { provider, modelId, maxTokens, ...callSettings } = options + this._provider = provider + this._config = { + modelId: modelId ?? provider.modelId, + ...(maxTokens != null && { maxTokens }), + ...callSettings, + } + } + + getConfig(): VercelModelConfig { + return { ...this._config } + } + + updateConfig(config: VercelModelConfig): void { + this._config = { ...this._config, ...config } + } + + async *stream(messages: Message[], options?: StreamOptions): AsyncIterable { + const prompt = formatMessages(messages, options?.systemPrompt) + const tools = options?.toolSpecs ? formatTools(options.toolSpecs) : undefined + const toolChoice = options?.toolChoice ? formatToolChoice(options.toolChoice) : undefined + + const { modelId: _, maxTokens, ...callSettings } = this._config + + const callOptions: LanguageModelV3CallOptions = { + prompt, + ...(tools && { tools }), + ...(toolChoice && { toolChoice }), + ...(maxTokens != null && { maxOutputTokens: maxTokens }), + ...callSettings, + } + + let result + try { + result = await this._provider.doStream(callOptions) + } catch (error) { + throw classifyError(error) + } + + const reader = result.stream.getReader() + const incrementalToolCallIds = new Set() + try { + while (true) { + let readResult + try { + readResult = await reader.read() + } catch (error) { + throw classifyError(error) + } + const { done, value } = readResult + if (done) break + if (value.type === 'tool-input-start') { + incrementalToolCallIds.add(value.id) + } + // Skip complete tool-call events when we already received incremental tool-input-* events for the same call + if (value.type === 'tool-call' && incrementalToolCallIds.has(value.toolCallId)) { + continue + } + yield* mapStreamPart(value) + } + } finally { + reader.releaseLock() + } + } +} + +/** + * Classifies an error from doStream into the appropriate Strands error type. + * + * @param error - The error thrown by the Vercel provider + * @returns A classified error (ContextWindowOverflowError, ModelThrottledError, or ModelError) + */ +function classifyError(error: unknown): Error { + const message = error instanceof Error ? error.message : String(error) + + if (APICallError.isInstance(error)) { + if (error.statusCode === 429) { + logger.debug(`throttled | error_message=<${message}>`) + return new ModelThrottledError(message, { cause: error }) + } + + const searchText = (error.responseBody ?? message).toLowerCase() + if (CONTEXT_WINDOW_OVERFLOW_PATTERNS.some((pattern) => searchText.includes(pattern))) { + return new ContextWindowOverflowError(message) + } + } + + if (CONTEXT_WINDOW_OVERFLOW_PATTERNS.some((pattern) => message.toLowerCase().includes(pattern))) { + return new ContextWindowOverflowError(message) + } + + return new ModelError(`Language model stream error: ${message}`, { cause: error }) +} + +/** + * Maps a single LanguageModelV3 stream part to zero or more Strands ModelStreamEvents. + */ +function* mapStreamPart(part: LanguageModelV3StreamPart): Generator { + switch (part.type) { + case 'stream-start': + yield new ModelMessageStartEvent({ type: 'modelMessageStartEvent', role: 'assistant' }) + break + + case 'text-start': + yield new ModelContentBlockStartEvent({ type: 'modelContentBlockStartEvent' }) + break + + case 'text-delta': + yield new ModelContentBlockDeltaEvent({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'textDelta', text: part.delta }, + }) + break + + case 'text-end': + yield new ModelContentBlockStopEvent({ type: 'modelContentBlockStopEvent' }) + break + + case 'reasoning-start': + yield new ModelContentBlockStartEvent({ type: 'modelContentBlockStartEvent' }) + break + + case 'reasoning-delta': + yield new ModelContentBlockDeltaEvent({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: part.delta }, + }) + break + + case 'reasoning-end': + yield new ModelContentBlockStopEvent({ type: 'modelContentBlockStopEvent' }) + break + + case 'tool-input-start': + yield new ModelContentBlockStartEvent({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: part.toolName, toolUseId: part.id }, + }) + break + + case 'tool-input-delta': + yield new ModelContentBlockDeltaEvent({ + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: part.delta }, + }) + break + + case 'tool-input-end': + yield new ModelContentBlockStopEvent({ type: 'modelContentBlockStopEvent' }) + break + + // Some providers (e.g. Responses API) emit only the complete tool-call without incremental tool-input-* events. + // Synthesize the start/delta/stop sequence so the aggregation logic builds ToolUseBlocks correctly. + case 'tool-call': + yield new ModelContentBlockStartEvent({ + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: part.toolName, toolUseId: part.toolCallId }, + }) + yield new ModelContentBlockDeltaEvent({ + type: 'modelContentBlockDeltaEvent', + delta: { + type: 'toolUseInputDelta', + input: typeof part.input === 'string' ? part.input : JSON.stringify(part.input), + }, + }) + yield new ModelContentBlockStopEvent({ type: 'modelContentBlockStopEvent' }) + break + + case 'finish': + yield new ModelMetadataEvent({ + type: 'modelMetadataEvent', + usage: mapUsage(part.usage), + }) + yield new ModelMessageStopEvent({ + type: 'modelMessageStopEvent', + stopReason: mapFinishReason(part.finishReason), + }) + break + + case 'error': + throw new ModelError( + `Language model stream error: ${part.error instanceof Error ? part.error.message : JSON.stringify(part.error)}`, + { cause: part.error } + ) + + case 'response-metadata': + logger.debug(`event_type=<${part.type}>, id=<${part.id}>, modelId=<${part.modelId}> | response metadata`) + break + + default: + logger.warn(`event_type=<${part.type}> | unsupported vercel stream event type, skipping`) + break + } +} + +/** + * Maps LanguageModelV3 finish reason to Strands StopReason. + */ +function mapFinishReason(finishReason: LanguageModelV3FinishReason): StopReason { + switch (finishReason.unified) { + case 'stop': + return 'endTurn' + case 'length': + return 'maxTokens' + case 'content-filter': + return 'contentFiltered' + case 'tool-calls': + return 'toolUse' + case 'other': + return 'endTurn' + case 'error': + throw new ModelError(`model finished with error | raw=<${finishReason.raw}>`) + default: + logger.warn(`finish_reason=<${finishReason.unified}> | unknown vercel finish reason, defaulting to endTurn`) + return 'endTurn' + } +} + +/** + * Maps LanguageModelV3 usage to Strands Usage. + */ +function mapUsage(usage: LanguageModelV3Usage): Usage { + const inputTokens = usage.inputTokens.total ?? 0 + const outputTokens = usage.outputTokens.total ?? 0 + return { + inputTokens, + outputTokens, + totalTokens: inputTokens + outputTokens, + ...(usage.inputTokens.cacheRead != null && { cacheReadInputTokens: usage.inputTokens.cacheRead }), + ...(usage.inputTokens.cacheWrite != null && { cacheWriteInputTokens: usage.inputTokens.cacheWrite }), + } +} + +/** + * Converts Strands messages + system prompt to LanguageModelV3 prompt format. + */ +function formatMessages(messages: Message[], systemPrompt?: SystemPrompt): LanguageModelV3Prompt { + const prompt: LanguageModelV3Prompt = [] + + if (systemPrompt) { + if (typeof systemPrompt === 'string') { + prompt.push({ role: 'system', content: systemPrompt }) + } else { + const textBlocks: string[] = [] + let hasCachePoints = false + let hasGuardContent = false + + for (const block of systemPrompt) { + if (isTextBlock(block)) { + textBlocks.push(block.text) + } else if (block.type === 'cachePointBlock') { + hasCachePoints = true + } else if (block.type === 'guardContentBlock') { + hasGuardContent = true + } + } + + if (hasCachePoints) { + logger.warn('cache points are not supported in vercel system prompts, ignoring cache points') + } + + if (hasGuardContent) { + logger.warn('guard content is not supported in vercel system prompts, removing guard content block') + } + + const text = textBlocks.join('') + if (text) { + prompt.push({ role: 'system', content: text }) + } + } + } + + // Build a global toolCallId -> toolName map across all messages + const toolNameMap = new Map() + for (const message of messages) { + for (const block of message.content) { + if (block.type === 'toolUseBlock') { + toolNameMap.set(block.toolUseId, block.name) + } + } + } + + for (const message of messages) { + if (message.role === 'user') { + formatUserMessage(message, prompt, toolNameMap) + } else if (message.role === 'assistant') { + formatAssistantMessage(message, prompt) + } + } + + return prompt +} + +/** + * Formats a Strands user message to LanguageModelV3 format. + * Tool result blocks are extracted into separate tool messages. + * + * @param message - The user message to format + * @param prompt - The prompt array to push formatted messages into + * @param toolNameMap - Map of toolCallId to toolName for resolving tool result names + */ +function formatUserMessage(message: Message, prompt: LanguageModelV3Prompt, toolNameMap: Map): void { + const content: Array = [] + const toolResults: LanguageModelV3ToolResultPart[] = [] + + for (const block of message.content) { + switch (block.type) { + case 'textBlock': + content.push({ type: 'text', text: block.text }) + break + case 'imageBlock': + case 'documentBlock': + case 'videoBlock': + content.push(...formatMediaBlock(block)) + break + case 'toolResultBlock': + toolResults.push({ + type: 'tool-result', + toolCallId: block.toolUseId, + toolName: toolNameMap.get(block.toolUseId) ?? '', + output: formatToolResultOutput(block.status, block.content), + }) + break + default: + logger.warn(`block_type=<${block.type}> | unsupported content type in vercel user message, skipping`) + break + } + } + + if (content.length > 0) { + prompt.push({ role: 'user', content }) + } + + for (const result of toolResults) { + prompt.push({ role: 'tool', content: [result] }) + } +} + +/** + * Formats a Strands assistant message to LanguageModelV3 format. + * + * @param message - The assistant message to format + * @param prompt - The prompt array to push formatted messages into + */ +function formatAssistantMessage(message: Message, prompt: LanguageModelV3Prompt): void { + const content: Array< + LanguageModelV3TextPart | LanguageModelV3FilePart | LanguageModelV3ReasoningPart | LanguageModelV3ToolCallPart + > = [] + + for (const block of message.content) { + switch (block.type) { + case 'textBlock': + content.push({ type: 'text', text: block.text }) + break + case 'reasoningBlock': + if (block.text) { + content.push({ type: 'reasoning', text: block.text }) + } + break + case 'toolUseBlock': + content.push({ + type: 'tool-call', + toolCallId: block.toolUseId, + toolName: block.name, + input: block.input, + }) + break + case 'toolResultBlock': + logger.warn('tool result in assistant message is not supported, skipping') + break + case 'imageBlock': + case 'documentBlock': + case 'videoBlock': + content.push(...formatMediaBlock(block)) + break + default: + logger.warn(`block_type=<${block.type}> | unsupported content type in vercel assistant message, skipping`) + break + } + } + + if (content.length > 0) { + prompt.push({ role: 'assistant', content }) + } +} + +/** + * Converts an image, document, or video block to LanguageModelV3 file/text parts. + */ +function formatMediaBlock( + block: ImageBlock | DocumentBlock | VideoBlock +): Array { + const parts: Array = [] + + switch (block.type) { + case 'imageBlock': { + const mediaType = toMimeType(block.format) ?? `image/${block.format}` + if (block.source.type === 'imageSourceBytes') { + parts.push({ type: 'file', data: block.source.bytes, mediaType }) + } else if (block.source.type === 'imageSourceUrl') { + parts.push({ type: 'file', data: new URL(block.source.url), mediaType }) + } else { + logger.warn(`source_type=<${block.source.type}> | unsupported image source type, skipping`) + } + break + } + case 'documentBlock': { + const mediaType = toMimeType(block.format) ?? `application/${block.format}` + if (block.source.type === 'documentSourceBytes') { + parts.push({ type: 'file', data: block.source.bytes, mediaType }) + } else if (block.source.type === 'documentSourceText') { + parts.push({ type: 'text', text: block.source.text }) + } else if (block.source.type === 'documentSourceContentBlock') { + for (const contentBlock of block.source.content) { + parts.push({ type: 'text', text: contentBlock.text }) + } + } else { + logger.warn(`source_type=<${block.source.type}> | unsupported document source type, skipping`) + } + break + } + case 'videoBlock': { + if (block.source.type === 'videoSourceBytes') { + parts.push({ + type: 'file', + data: block.source.bytes, + mediaType: toMimeType(block.format) ?? `video/${block.format}`, + }) + } else { + logger.warn(`source_type=<${block.source.type}> | unsupported video source type, skipping`) + } + break + } + } + + return parts +} + +/** + * Formats tool result content to LanguageModelV3 ToolResultOutput. + */ +function formatToolResultOutput( + status: string, + content: ReadonlyArray +): LanguageModelV3ToolResultOutput { + if (status === 'error') { + const errorText = content + .filter((c): c is ToolResultContent & { text: string } => 'text' in c && typeof c.text === 'string') + .map((c) => c.text) + .join('\n') + return { type: 'error-text', value: errorText || 'Tool execution failed' } + } + + const value: Array<{ type: 'text'; text: string } | { type: 'file-data'; data: string; mediaType: string }> = [] + for (const c of content) { + switch (c.type) { + case 'textBlock': + value.push({ type: 'text', text: c.text }) + break + case 'jsonBlock': + value.push({ type: 'text', text: JSON.stringify(c.json) }) + break + case 'imageBlock': { + const mediaType = toMimeType(c.format) ?? `image/${c.format}` + if (c.source.type === 'imageSourceBytes') { + value.push({ type: 'file-data', data: encodeBase64(c.source.bytes), mediaType }) + } else if (c.source.type === 'imageSourceUrl') { + value.push({ type: 'text', text: c.source.url }) + } else { + logger.warn(`source_type=<${c.source.type}> | unsupported image source in vercel tool result, skipping`) + } + break + } + case 'documentBlock': { + const mediaType = toMimeType(c.format) ?? `application/${c.format}` + if (c.source.type === 'documentSourceBytes') { + value.push({ type: 'file-data', data: encodeBase64(c.source.bytes), mediaType }) + } else if (c.source.type === 'documentSourceText') { + value.push({ type: 'text', text: c.source.text }) + } else if (c.source.type === 'documentSourceContentBlock') { + for (const block of c.source.content) { + value.push({ type: 'text', text: block.text }) + } + } else { + logger.warn(`source_type=<${c.source.type}> | unsupported document source in vercel tool result, skipping`) + } + break + } + case 'videoBlock': { + const mediaType = toMimeType(c.format) ?? `video/${c.format}` + if (c.source.type === 'videoSourceBytes') { + value.push({ type: 'file-data', data: encodeBase64(c.source.bytes), mediaType }) + } else { + logger.warn(`source_type=<${c.source.type}> | unsupported video source in vercel tool result, skipping`) + } + break + } + default: + logger.warn( + `block_type=<${(c as unknown as { type: string }).type}> | unsupported content type in vercel tool result, skipping` + ) + break + } + } + return { type: 'content', value } +} + +/** + * Converts Strands ToolSpec[] to LanguageModelV3 FunctionTool[]. + */ +function formatTools(toolSpecs: ToolSpec[]): LanguageModelV3FunctionTool[] { + return toolSpecs.map((spec) => ({ + type: 'function' as const, + name: spec.name, + description: spec.description, + inputSchema: (spec.inputSchema ?? { + type: 'object', + properties: {}, + }) as LanguageModelV3FunctionTool['inputSchema'], + })) +} + +/** + * Converts Strands ToolChoice to LanguageModelV3 ToolChoice. + */ +function formatToolChoice(toolChoice: ToolChoice): LanguageModelV3ToolChoice { + if ('auto' in toolChoice) return { type: 'auto' } + if ('any' in toolChoice) return { type: 'required' } + if ('tool' in toolChoice) return { type: 'tool', toolName: toolChoice.tool.name } + return { type: 'auto' } +} + +/** + * Type guard for TextBlock instances in system prompt content. + */ +function isTextBlock(block: unknown): block is TextBlock { + return typeof block === 'object' && block !== null && 'text' in block && typeof (block as TextBlock).text === 'string' +} diff --git a/strands-ts/src/multiagent/__tests__/events.test.ts b/strands-ts/src/multiagent/__tests__/events.test.ts new file mode 100644 index 0000000000..ce60c115d9 --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/events.test.ts @@ -0,0 +1,567 @@ +import { describe, expect, it } from 'vitest' +import { + MultiAgentInitializedEvent, + BeforeMultiAgentInvocationEvent, + AfterMultiAgentInvocationEvent, + BeforeNodeCallEvent, + AfterNodeCallEvent, + NodeStreamUpdateEvent, + NodeResultEvent, + NodeCancelEvent, + MultiAgentHandoffEvent, + MultiAgentResultEvent, +} from '../events.js' +import { MultiAgentResult, MultiAgentState, NodeResult, Status } from '../state.js' +import type { MultiAgent } from '../multiagent.js' +import type { AgentStreamEvent } from '../../types/agent.js' + +const mockOrchestrator: MultiAgent = { + id: 'test-orchestrator', + invoke: async () => new MultiAgentResult({ results: [], duration: 0 }), + // eslint-disable-next-line require-yield + async *stream() { + return new MultiAgentResult({ results: [], duration: 0 }) + }, + addHook: () => () => {}, +} + +describe('MultiAgentInitializedEvent', () => { + it('creates instance with correct properties', () => { + const event = new MultiAgentInitializedEvent({ orchestrator: mockOrchestrator }) + + expect(event).toEqual({ + type: 'multiAgentInitializedEvent', + orchestrator: mockOrchestrator, + }) + // @ts-expect-error verifying that property is readonly + event.orchestrator = mockOrchestrator + }) + + it('returns false for _shouldReverseCallbacks', () => { + const event = new MultiAgentInitializedEvent({ orchestrator: mockOrchestrator }) + expect(event._shouldReverseCallbacks()).toBe(false) + }) + + describe('toJSON', () => { + const event = new MultiAgentInitializedEvent({ orchestrator: mockOrchestrator }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ type: 'multiAgentInitializedEvent' }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual(['orchestrator']) + }) + }) +}) + +describe('BeforeMultiAgentInvocationEvent', () => { + it('creates instance with correct properties', () => { + const state = new MultiAgentState() + const event = new BeforeMultiAgentInvocationEvent({ + orchestrator: mockOrchestrator, + state, + invocationState: {}, + }) + + expect(event).toEqual({ + type: 'beforeMultiAgentInvocationEvent', + orchestrator: mockOrchestrator, + state, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.orchestrator = mockOrchestrator + // @ts-expect-error verifying that property is readonly + event.state = state + }) + + it('returns false for _shouldReverseCallbacks', () => { + const state = new MultiAgentState() + const event = new BeforeMultiAgentInvocationEvent({ + orchestrator: mockOrchestrator, + state, + invocationState: {}, + }) + expect(event._shouldReverseCallbacks()).toBe(false) + }) + + describe('toJSON', () => { + const event = new BeforeMultiAgentInvocationEvent({ + orchestrator: mockOrchestrator, + state: new MultiAgentState(), + invocationState: {}, + }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ type: 'beforeMultiAgentInvocationEvent' }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual(['orchestrator', 'state', 'invocationState']) + }) + }) +}) + +describe('AfterMultiAgentInvocationEvent', () => { + it('creates instance with correct properties', () => { + const state = new MultiAgentState() + const event = new AfterMultiAgentInvocationEvent({ + orchestrator: mockOrchestrator, + state, + invocationState: {}, + }) + + expect(event).toEqual({ + type: 'afterMultiAgentInvocationEvent', + orchestrator: mockOrchestrator, + state, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.orchestrator = mockOrchestrator + // @ts-expect-error verifying that property is readonly + event.state = state + }) + + it('returns true for _shouldReverseCallbacks', () => { + const state = new MultiAgentState() + const event = new AfterMultiAgentInvocationEvent({ + orchestrator: mockOrchestrator, + state, + invocationState: {}, + }) + expect(event._shouldReverseCallbacks()).toBe(true) + }) + + describe('toJSON', () => { + const event = new AfterMultiAgentInvocationEvent({ + orchestrator: mockOrchestrator, + state: new MultiAgentState(), + invocationState: {}, + }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ type: 'afterMultiAgentInvocationEvent' }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual(['orchestrator', 'state', 'invocationState']) + }) + }) +}) + +describe('BeforeNodeCallEvent', () => { + it('creates instance with correct properties', () => { + const state = new MultiAgentState() + const event = new BeforeNodeCallEvent({ + orchestrator: mockOrchestrator, + state, + nodeId: 'node-1', + invocationState: {}, + }) + + expect(event).toEqual({ + type: 'beforeNodeCallEvent', + orchestrator: mockOrchestrator, + state, + nodeId: 'node-1', + cancel: false, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.orchestrator = mockOrchestrator + // @ts-expect-error verifying that property is readonly + event.state = state + // @ts-expect-error verifying that property is readonly + event.nodeId = 'node-1' + }) + + it('returns false for _shouldReverseCallbacks', () => { + const state = new MultiAgentState() + const event = new BeforeNodeCallEvent({ + orchestrator: mockOrchestrator, + state, + nodeId: 'node-1', + invocationState: {}, + }) + expect(event._shouldReverseCallbacks()).toBe(false) + }) + + it('allows cancel to be set to true', () => { + const state = new MultiAgentState() + const event = new BeforeNodeCallEvent({ + orchestrator: mockOrchestrator, + state, + nodeId: 'node-1', + invocationState: {}, + }) + + expect(event.cancel).toBe(false) + event.cancel = true + expect(event.cancel).toBe(true) + }) + + it('allows cancel to be set to a string message', () => { + const state = new MultiAgentState() + const event = new BeforeNodeCallEvent({ + orchestrator: mockOrchestrator, + state, + nodeId: 'node-1', + invocationState: {}, + }) + + event.cancel = 'node is not ready' + expect(event.cancel).toBe('node is not ready') + }) + + describe('toJSON', () => { + const event = new BeforeNodeCallEvent({ + orchestrator: mockOrchestrator, + state: new MultiAgentState(), + nodeId: 'node-1', + invocationState: {}, + }) + event.cancel = true + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'beforeNodeCallEvent', + nodeId: 'node-1', + }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual([ + 'orchestrator', + 'state', + 'invocationState', + 'cancel', + ]) + }) + }) +}) + +describe('AfterNodeCallEvent', () => { + it('creates instance with correct properties', () => { + const state = new MultiAgentState() + const error = new Error('node failed') + const event = new AfterNodeCallEvent({ + orchestrator: mockOrchestrator, + state, + nodeId: 'node-1', + invocationState: {}, + error, + }) + + expect(event).toEqual({ + type: 'afterNodeCallEvent', + orchestrator: mockOrchestrator, + state, + nodeId: 'node-1', + invocationState: {}, + error, + }) + // @ts-expect-error verifying that property is readonly + event.orchestrator = mockOrchestrator + // @ts-expect-error verifying that property is readonly + event.state = state + // @ts-expect-error verifying that property is readonly + event.nodeId = 'node-1' + }) + + it('returns true for _shouldReverseCallbacks', () => { + const state = new MultiAgentState() + const event = new AfterNodeCallEvent({ + orchestrator: mockOrchestrator, + state, + nodeId: 'node-1', + invocationState: {}, + }) + expect(event._shouldReverseCallbacks()).toBe(true) + }) + + describe('toJSON', () => { + const event = new AfterNodeCallEvent({ + orchestrator: mockOrchestrator, + state: new MultiAgentState(), + nodeId: 'node-1', + invocationState: {}, + error: new Error('node failed'), + }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'afterNodeCallEvent', + nodeId: 'node-1', + error: { message: 'node failed' }, + }) + }) + + it('serializes without error', () => { + const event = new AfterNodeCallEvent({ + orchestrator: mockOrchestrator, + state: new MultiAgentState(), + nodeId: 'node-1', + invocationState: {}, + }) + + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'afterNodeCallEvent', + nodeId: 'node-1', + }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual(['orchestrator', 'state', 'invocationState']) + }) + }) +}) + +describe('NodeStreamUpdateEvent', () => { + it('creates instance with correct properties', () => { + const state = new MultiAgentState() + const innerEvent = { source: 'agent', event: { type: 'beforeInvocationEvent' } as AgentStreamEvent } as const + const event = new NodeStreamUpdateEvent({ + nodeId: 'node-1', + nodeType: 'agentNode', + state, + inner: innerEvent, + invocationState: {}, + }) + + expect(event).toEqual({ + type: 'nodeStreamUpdateEvent', + nodeId: 'node-1', + nodeType: 'agentNode', + state, + inner: innerEvent, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.nodeId = 'node-1' + // @ts-expect-error verifying that property is readonly + event.nodeType = 'agentNode' + // @ts-expect-error verifying that property is readonly + event.state = state + // @ts-expect-error verifying that property is readonly + event.inner = innerEvent + }) + + describe('toJSON', () => { + const innerEvent = { source: 'agent', event: { type: 'beforeInvocationEvent' } as AgentStreamEvent } as const + const event = new NodeStreamUpdateEvent({ + nodeId: 'node-1', + nodeType: 'agentNode', + state: new MultiAgentState(), + inner: innerEvent, + invocationState: {}, + }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'nodeStreamUpdateEvent', + nodeId: 'node-1', + nodeType: 'agentNode', + inner: { source: 'agent', event: { type: 'beforeInvocationEvent' } }, + }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual(['state', 'invocationState']) + }) + }) +}) + +describe('NodeResultEvent', () => { + it('creates instance with correct properties', () => { + const state = new MultiAgentState() + const result = new NodeResult({ nodeId: 'node-1', status: Status.COMPLETED, duration: 100 }) + const event = new NodeResultEvent({ nodeId: 'node-1', nodeType: 'agentNode', state, result, invocationState: {} }) + + expect(event).toEqual({ + type: 'nodeResultEvent', + nodeId: 'node-1', + nodeType: 'agentNode', + state, + result, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.nodeId = 'node-1' + // @ts-expect-error verifying that property is readonly + event.nodeType = 'agentNode' + // @ts-expect-error verifying that property is readonly + event.state = state + // @ts-expect-error verifying that property is readonly + event.result = result + }) + + describe('toJSON', () => { + const event = new NodeResultEvent({ + nodeId: 'node-1', + nodeType: 'agentNode', + state: new MultiAgentState(), + result: new NodeResult({ nodeId: 'node-1', status: Status.COMPLETED, duration: 100 }), + invocationState: {}, + }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'nodeResultEvent', + nodeId: 'node-1', + nodeType: 'agentNode', + result: { + type: 'nodeResult', + nodeId: 'node-1', + status: 'COMPLETED', + duration: 100, + content: [], + }, + }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual(['state', 'invocationState']) + }) + }) +}) + +describe('NodeCancelEvent', () => { + it('creates instance with correct properties', () => { + const state = new MultiAgentState() + const event = new NodeCancelEvent({ nodeId: 'node-1', state, message: 'cancelled by hook', invocationState: {} }) + + expect(event).toEqual({ + type: 'nodeCancelEvent', + nodeId: 'node-1', + state, + message: 'cancelled by hook', + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.nodeId = 'node-1' + // @ts-expect-error verifying that property is readonly + event.state = state + // @ts-expect-error verifying that property is readonly + event.message = 'cancelled by hook' + }) + + describe('toJSON', () => { + const event = new NodeCancelEvent({ + nodeId: 'node-1', + state: new MultiAgentState(), + message: 'cancelled by hook', + invocationState: {}, + }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'nodeCancelEvent', + nodeId: 'node-1', + message: 'cancelled by hook', + }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual(['state', 'invocationState']) + }) + }) +}) + +describe('MultiAgentHandoffEvent', () => { + it('creates instance with correct properties', () => { + const state = new MultiAgentState() + const event = new MultiAgentHandoffEvent({ + source: 'node-a', + targets: ['node-b', 'node-c'], + state, + invocationState: {}, + }) + + expect(event).toEqual({ + type: 'multiAgentHandoffEvent', + source: 'node-a', + targets: ['node-b', 'node-c'], + state, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.source = 'node-a' + // @ts-expect-error verifying that property is readonly + event.targets = [] + // @ts-expect-error verifying that property is readonly + event.state = state + }) + + describe('toJSON', () => { + const event = new MultiAgentHandoffEvent({ + source: 'node-a', + targets: ['node-b', 'node-c'], + state: new MultiAgentState(), + invocationState: {}, + }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'multiAgentHandoffEvent', + source: 'node-a', + targets: ['node-b', 'node-c'], + }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual(['state', 'invocationState']) + }) + }) +}) + +describe('MultiAgentResultEvent', () => { + it('creates instance with correct properties', () => { + const result = new MultiAgentResult({ results: [], duration: 0 }) + const event = new MultiAgentResultEvent({ result, invocationState: {} }) + + expect(event).toEqual({ + type: 'multiAgentResultEvent', + result, + invocationState: {}, + }) + // @ts-expect-error verifying that property is readonly + event.result = result + }) + + describe('toJSON', () => { + const event = new MultiAgentResultEvent({ + result: new MultiAgentResult({ results: [], duration: 500 }), + invocationState: {}, + }) + + it('serializes', () => { + expect(JSON.parse(JSON.stringify(event))).toStrictEqual({ + type: 'multiAgentResultEvent', + result: { + type: 'multiAgentResult', + status: 'COMPLETED', + results: [], + content: [], + duration: 500, + usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + }, + }) + }) + + it('only excludes expected fields', () => { + const json = event.toJSON() + expect(Object.keys(event).filter((k) => !(k in json))).toStrictEqual(['invocationState']) + }) + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/graph.invocation-state.test.ts b/strands-ts/src/multiagent/__tests__/graph.invocation-state.test.ts new file mode 100644 index 0000000000..161b79fa90 --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/graph.invocation-state.test.ts @@ -0,0 +1,122 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { BeforeModelCallEvent } from '../../hooks/events.js' +import { TextBlock } from '../../types/messages.js' +import { Graph } from '../graph.js' +import { + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentHandoffEvent, + MultiAgentResultEvent, + NodeResultEvent, + NodeStreamUpdateEvent, +} from '../events.js' +import type { InvocationState } from '../../types/agent.js' + +describe('Graph invocationState forwarding', () => { + it('forwards invocationState to every node and mutations from one node are visible to the next', async () => { + const nodeAObserved: InvocationState[] = [] + const nodeBObserved: InvocationState[] = [] + + const agentA = new Agent({ + model: new MockMessageModel().addTurn(new TextBlock('A done')), + printer: false, + id: 'a', + }) + agentA.addHook(BeforeModelCallEvent, (event) => { + nodeAObserved.push(event.invocationState) + event.invocationState.touchedByA = true + }) + + const agentB = new Agent({ + model: new MockMessageModel().addTurn(new TextBlock('B done')), + printer: false, + id: 'b', + }) + agentB.addHook(BeforeModelCallEvent, (event) => { + nodeBObserved.push(event.invocationState) + }) + + const graph = new Graph({ + nodes: [agentA, agentB], + edges: [{ source: 'a', target: 'b' }], + }) + + const state: InvocationState = { requestId: 'r-1' } + await graph.invoke('hello', { invocationState: state }) + + // Both nodes observe the same object reference. + expect(nodeAObserved[0]).toBe(state) + expect(nodeBObserved[0]).toBe(state) + + // Node B sees node A's mutation. + expect(nodeBObserved[0]?.touchedByA).toBe(true) + expect(state.touchedByA).toBe(true) + }) + + it('defaults invocationState to {} when none is passed', async () => { + let observed: InvocationState | undefined + + const agentA = new Agent({ + model: new MockMessageModel().addTurn(new TextBlock('A done')), + printer: false, + id: 'a', + }) + agentA.addHook(BeforeModelCallEvent, (event) => { + observed = event.invocationState + }) + + const graph = new Graph({ + nodes: [agentA], + edges: [], + }) + + await graph.invoke('hello') + + expect(observed).toEqual({}) + }) + + it('every orchestrator and node event in a run carries the same invocationState reference', async () => { + const agentA = new Agent({ + model: new MockMessageModel().addTurn(new TextBlock('A done')), + printer: false, + id: 'a', + }) + const agentB = new Agent({ + model: new MockMessageModel().addTurn(new TextBlock('B done')), + printer: false, + id: 'b', + }) + + const graph = new Graph({ + nodes: [agentA, agentB], + edges: [{ source: 'a', target: 'b' }], + }) + + const state: InvocationState = { requestId: 'r-1' } + const observed: { label: string; ref: InvocationState }[] = [] + + const record = (label: string, ref: InvocationState): void => { + observed.push({ label, ref }) + } + graph.addHook(BeforeMultiAgentInvocationEvent, (e) => record('BeforeMultiAgentInvocation', e.invocationState)) + graph.addHook(AfterMultiAgentInvocationEvent, (e) => record('AfterMultiAgentInvocation', e.invocationState)) + graph.addHook(BeforeNodeCallEvent, (e) => record(`BeforeNodeCall:${e.nodeId}`, e.invocationState)) + graph.addHook(AfterNodeCallEvent, (e) => record(`AfterNodeCall:${e.nodeId}`, e.invocationState)) + graph.addHook(NodeStreamUpdateEvent, (e) => record(`NodeStreamUpdate:${e.nodeId}`, e.invocationState)) + graph.addHook(NodeResultEvent, (e) => record(`NodeResult:${e.nodeId}`, e.invocationState)) + graph.addHook(MultiAgentHandoffEvent, (e) => record('MultiAgentHandoff', e.invocationState)) + graph.addHook(MultiAgentResultEvent, (e) => record('MultiAgentResult', e.invocationState)) + + await graph.invoke('hello', { invocationState: state }) + + // Every event observed at the orchestrator level must share the caller's reference. + expect(observed.length).toBeGreaterThan(0) + for (const { label, ref } of observed) { + expect(ref, `event=${label} saw a different invocationState object`).toBe(state) + } + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/graph.test.ts b/strands-ts/src/multiagent/__tests__/graph.test.ts new file mode 100644 index 0000000000..23863a432a --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/graph.test.ts @@ -0,0 +1,870 @@ +import { describe, expect, it, vi } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { MockSnapshotStorage } from '../../__fixtures__/mock-storage-provider.js' +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import { createCancellableAgent } from '../../__fixtures__/agent-helpers.js' +import { AfterNodeCallEvent, BeforeNodeCallEvent, MultiAgentInitializedEvent } from '../events.js' +import { TextBlock, type ContentBlockData } from '../../types/messages.js' +import { Status, MultiAgentState } from '../state.js' +import { AgentNode, MultiAgentNode } from '../nodes.js' +import { Graph } from '../graph.js' +import { SessionManager } from '../../session/session-manager.js' + +function makeAgent(id: string, text = 'reply'): Agent { + const model = new MockMessageModel().addTurn(new TextBlock(text)) + return new Agent({ model, printer: false, id }) +} + +describe('Graph', () => { + describe('constructor', () => { + it('defaults id to "graph"', () => { + const graph = new Graph({ + nodes: [makeAgent('a')], + edges: [], + }) + expect(graph.id).toBe('graph') + }) + + it('accepts a custom id', () => { + const graph = new Graph({ + nodes: [makeAgent('a')], + edges: [], + id: 'my-graph', + }) + expect(graph.id).toBe('my-graph') + }) + + it('accepts agent node options', () => { + const graph = new Graph({ + nodes: [{ agent: makeAgent('a') }], + edges: [], + }) + expect(graph.nodes.get('a')).toBeInstanceOf(AgentNode) + }) + + it('accepts multiAgent node options', () => { + const inner = new Graph({ + id: 'inner', + nodes: [makeAgent('x')], + edges: [], + }) + + const graph = new Graph({ + nodes: [{ type: 'multiAgent', orchestrator: inner }], + edges: [], + }) + expect(graph.nodes.get('inner')).toBeInstanceOf(MultiAgentNode) + }) + + it('accepts pre-built Node instances', () => { + const node = new AgentNode({ agent: makeAgent('a') }) + const graph = new Graph({ + nodes: [node], + edges: [], + }) + expect(graph.nodes.get('a')).toBe(node) + }) + + it('accepts edge options', () => { + const graph = new Graph({ + nodes: [makeAgent('a'), makeAgent('b')], + edges: [{ source: 'a', target: 'b' }], + }) + expect(graph.edges).toHaveLength(1) + expect(graph.edges[0]).toEqual( + expect.objectContaining({ + source: expect.objectContaining({ id: 'a' }), + target: expect.objectContaining({ id: 'b' }), + }) + ) + }) + + it('throws on duplicate node IDs', () => { + const agent = makeAgent('a') + expect( + () => + new Graph({ + nodes: [agent, agent], + edges: [], + }) + ).toThrow('node_id= | duplicate node id') + }) + + it('throws on edge referencing unknown source node', () => { + expect( + () => + new Graph({ + nodes: [makeAgent('a')], + edges: [['missing', 'a']], + }) + ).toThrow('source= | edge references unknown source node') + }) + + it('throws on edge referencing unknown target node', () => { + expect( + () => + new Graph({ + nodes: [makeAgent('a')], + edges: [['a', 'missing']], + }) + ).toThrow('target= | edge references unknown target node') + }) + + it('throws when graph has no source nodes', () => { + expect( + () => + new Graph({ + nodes: [makeAgent('a'), makeAgent('b')], + edges: [ + ['a', 'b'], + ['b', 'a'], + ], + }) + ).toThrow('graph has no source nodes') + }) + + it('throws on unreachable nodes', () => { + expect( + () => + new Graph({ + nodes: [makeAgent('a'), makeAgent('b'), makeAgent('island1'), makeAgent('island2')], + edges: [ + ['a', 'b'], + ['island1', 'island2'], + ['island2', 'island1'], + ], + }) + ).toThrow('node_id= | unreachable from any source node') + }) + + it('throws when explicit source references unknown node', () => { + expect( + () => + new Graph({ + nodes: [makeAgent('a')], + edges: [], + sources: ['missing'], + }) + ).toThrow('source= | source references unknown node') + }) + + it('throws when maxSteps < 1', () => { + expect( + () => + new Graph({ + nodes: [makeAgent('a')], + edges: [], + maxSteps: 0, + }) + ).toThrow('max_steps=<0> | must be at least 1') + }) + + it('throws when maxConcurrency < 1', () => { + expect( + () => + new Graph({ + nodes: [makeAgent('a')], + edges: [], + maxConcurrency: 0, + }) + ).toThrow('max_concurrency=<0> | must be at least 1') + }) + + it('defaults maxConcurrency, maxSteps, timeout, and nodeTimeout to Infinity', () => { + const graph = new Graph({ + nodes: [makeAgent('a')], + edges: [], + }) + expect(graph.config.maxConcurrency).toBe(Infinity) + expect(graph.config.maxSteps).toBe(Infinity) + expect(graph.config.timeout).toBe(Infinity) + expect(graph.config.nodeTimeout).toBe(Infinity) + }) + + it('throws when timeout < 1', () => { + expect( + () => + new Graph({ + nodes: [makeAgent('a')], + edges: [], + timeout: 0, + }) + ).toThrow('timeout=<0> | must be at least 1') + }) + + it('throws when nodeTimeout < 1', () => { + expect( + () => + new Graph({ + nodes: [makeAgent('a')], + edges: [], + nodeTimeout: 0, + }) + ).toThrow('node_timeout=<0> | must be at least 1') + }) + }) + + describe('invoke', () => { + it('executes linear graph (A -> B -> C) in order', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + ['a', 'b'], + ['b', 'c'], + ], + }) + + const result = await graph.invoke('start') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + content: expect.arrayContaining([expect.objectContaining({ type: 'textBlock', text: 'c-reply' })]), + duration: expect.any(Number), + }) + ) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'b', 'c']) + }) + + it('executes parallel graph (A -> B, A -> C) with B and C after A', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + ['a', 'b'], + ['a', 'c'], + ], + }) + + const result = await graph.invoke('start') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + content: expect.arrayContaining([ + expect.objectContaining({ type: 'textBlock', text: 'b-reply' }), + expect.objectContaining({ type: 'textBlock', text: 'c-reply' }), + ]), + duration: expect.any(Number), + }) + ) + expect(result.results.map((r) => r.nodeId).sort()).toStrictEqual(['a', 'b', 'c']) + }) + + it('waits for all dependencies before executing join node (A -> C, B -> C)', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + ['a', 'c'], + ['b', 'c'], + ], + maxConcurrency: 1, + }) + + const result = await graph.invoke('start') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + content: expect.arrayContaining([expect.objectContaining({ type: 'textBlock', text: 'c-reply' })]), + duration: expect.any(Number), + }) + ) + expect(result.results).toHaveLength(3) + }) + + it('executes nested graph through MultiAgentNode', async () => { + const inner = new Graph({ + id: 'inner', + nodes: [makeAgent('x', 'inner-reply')], + edges: [], + }) + + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), inner], + edges: [['a', 'inner']], + }) + + const result = await graph.invoke('start') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + content: expect.arrayContaining([expect.objectContaining({ type: 'textBlock', text: 'inner-reply' })]), + duration: expect.any(Number), + }) + ) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'inner']) + }) + + it('uses explicit sources instead of auto-detection', async () => { + const graph = new Graph({ + nodes: [makeAgent('a'), makeAgent('b')], + edges: [['a', 'b'], { source: 'b', target: 'a', handler: () => false }], + sources: ['a'], + }) + + const result = await graph.invoke('go') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + duration: expect.any(Number), + }) + ) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'b']) + }) + + it('evaluates conditional edges', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + { source: 'a', target: 'b', handler: () => true }, + { source: 'a', target: 'c', handler: () => false }, + ], + }) + + const result = await graph.invoke('start') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + duration: expect.any(Number), + }) + ) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'b']) + }) + + it('evaluates conditional edges on join node (A -> C false, B -> C)', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [{ source: 'a', target: 'c', handler: () => false }, ['b', 'c']], + }) + + const result = await graph.invoke('start') + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId).sort()).toStrictEqual(['a', 'b']) + }) + + it('evaluates conditional edges on join node (A -> C true, B -> C true)', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + { source: 'a', target: 'c', handler: () => true }, + { source: 'b', target: 'c', handler: () => true }, + ], + }) + + const result = await graph.invoke('start') + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId).sort()).toStrictEqual(['a', 'b', 'c']) + }) + + it('passes task + dependency content to downstream nodes', async () => { + const agentB = makeAgent('b') + const streamSpy = vi.spyOn(agentB, 'stream') + + const graph = new Graph({ + nodes: [makeAgent('a', 'from-a'), agentB], + edges: [['a', 'b']], + }) + + await graph.invoke('task-input') + + expect(streamSpy).toHaveBeenCalled() + const input = streamSpy.mock.calls[0]![0] as TextBlock[] + expect(input.map((b) => b.text)).toStrictEqual(['task-input', '[node: a]', 'from-a']) + }) + + it('converts ContentBlockData[] input to ContentBlock instances for downstream nodes', async () => { + const agentB = makeAgent('b') + const streamSpy = vi.spyOn(agentB, 'stream') + + const graph = new Graph({ + nodes: [makeAgent('a', 'from-a'), agentB], + edges: [['a', 'b']], + }) + + const dataInput: ContentBlockData[] = [{ text: 'data-input' }] + await graph.invoke(dataInput) + + expect(streamSpy).toHaveBeenCalled() + const input = streamSpy.mock.calls[0]![0] as TextBlock[] + expect(input[0]).toBeInstanceOf(TextBlock) + expect(input.map((b) => b.text)).toStrictEqual(['data-input', '[node: a]', 'from-a']) + }) + + it('returns failed result when agent throws', async () => { + const model = new MockMessageModel().addTurn(new Error('agent exploded')) + const agent = new Agent({ model, printer: false, id: 'a' }) + + const graph = new Graph({ + nodes: [agent, makeAgent('b', 'b-reply')], + edges: [['a', 'b']], + }) + + const result = await graph.invoke('go') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.FAILED, + duration: expect.any(Number), + }) + ) + expect(result.results).toHaveLength(1) + expect(result.results[0]).toEqual(expect.objectContaining({ nodeId: 'a', status: Status.FAILED })) + }) + + it('propagates unexpected errors from node execution', async () => { + const graph = new Graph({ + nodes: [makeAgent('a')], + edges: [], + }) + + const node = graph.nodes.get('a')! + // eslint-disable-next-line require-yield + vi.spyOn(node, 'stream').mockImplementation(async function* () { + throw new Error('unexpected failure') + }) + + await expect(graph.invoke('go')).rejects.toThrow('unexpected failure') + }) + + it('throws when maxSteps is exceeded', async () => { + const graph = new Graph({ + nodes: [makeAgent('a'), makeAgent('b'), makeAgent('c')], + edges: [ + ['a', 'b'], + ['b', 'c'], + ], + maxSteps: 2, + }) + + await expect(graph.invoke('go')).rejects.toThrow('max steps reached') + }) + + it('throws when a node exceeds nodeTimeout', async () => { + const graph = new Graph({ + nodes: [{ agent: createCancellableAgent('slow', 100) }], + edges: [], + nodeTimeout: 20, + }) + + await expect(graph.invoke('go')).rejects.toThrow(/node_timeout=<20>, node_id=/) + }) + + it('applies per-node timeout over nodeTimeout', async () => { + const graph = new Graph({ + nodes: [{ agent: createCancellableAgent('slow', 100), timeout: 15 }], + edges: [], + nodeTimeout: 10_000, + }) + + await expect(graph.invoke('go')).rejects.toThrow(/node_timeout=<15>, node_id=/) + }) + + it('does not throw when nodeTimeout is Infinity', async () => { + const graph = new Graph({ + nodes: [{ agent: createCancellableAgent('a', 20) }], + edges: [], + nodeTimeout: Infinity, + }) + + const result = await graph.invoke('go') + expect(result.results).toHaveLength(1) + expect(result.results[0]?.status).toBe(Status.COMPLETED) + }) + + it('per-node timeout of Infinity disables a finite nodeTimeout', async () => { + const graph = new Graph({ + nodes: [{ agent: createCancellableAgent('slow', 30), timeout: Infinity }], + edges: [], + nodeTimeout: 10, + }) + + const result = await graph.invoke('go') + expect(result.results).toHaveLength(1) + expect(result.results[0]?.status).toBe(Status.COMPLETED) + }) + + it('throws when timeout is exceeded', async () => { + const graph = new Graph({ + nodes: [{ agent: createCancellableAgent('a', 30) }, { agent: createCancellableAgent('b', 30) }], + edges: [['a', 'b']], + timeout: 20, + }) + + await expect(graph.invoke('go')).rejects.toThrow(/timeout=<20>/) + }) + + it('calls initialize only once across invocations', async () => { + let callCount = 0 + + const graph = new Graph({ + nodes: [makeAgent('a')], + edges: [], + }) + + graph.addHook(MultiAgentInitializedEvent, () => { + callCount++ + }) + + await graph.invoke('first') + await graph.invoke('second') + + expect(callCount).toBe(1) + }) + + it('respects maxConcurrency limit', async () => { + let concurrent = 0 + let maxConcurrent = 0 + + const graph = new Graph({ + nodes: [makeAgent('a'), makeAgent('b'), makeAgent('c')], + edges: [ + ['a', 'b'], + ['a', 'c'], + ], + maxConcurrency: 1, + }) + + graph.addHook(BeforeNodeCallEvent, () => { + concurrent++ + maxConcurrent = Math.max(maxConcurrent, concurrent) + }) + graph.addHook(AfterNodeCallEvent, () => { + concurrent-- + }) + + const result = await graph.invoke('go') + + expect(result.status).toBe(Status.COMPLETED) + expect(maxConcurrent).toBe(1) + }) + + it('preserves agent messages and state after execution', async () => { + const agent = makeAgent('a', 'reply') + const messagesBefore = [...agent.messages] + const stateBefore = agent.appState.getAll() + + const graph = new Graph({ + nodes: [agent], + edges: [], + }) + + await graph.invoke('hello') + + expect(agent.messages).toStrictEqual(messagesBefore) + expect(agent.appState.getAll()).toStrictEqual(stateBefore) + }) + + it('executes join node exactly once when all parents complete concurrently', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + ['a', 'c'], + ['b', 'c'], + ], + }) + + const nodeC = graph.nodes.get('c')! + const streamSpy = vi.spyOn(nodeC, 'stream') + + const result = await graph.invoke('go') + + expect(result.status).toBe(Status.COMPLETED) + expect(streamSpy).toHaveBeenCalledTimes(1) + }) + + it('re-executes node in a cycle when conditional edge allows re-entry', async () => { + let visits = 0 + + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply')], + edges: [ + { + source: 'a', + target: 'a', + handler: () => { + visits++ + return visits < 2 + }, + }, + ], + sources: ['a'], + }) + + const result = await graph.invoke('go') + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results).toHaveLength(2) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'a']) + expect(visits).toBe(2) + }) + }) + + describe('stream', () => { + it('yields lifecycle events in correct order for single node', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply')], + edges: [], + }) + + const { items, result } = await collectGenerator(graph.stream('go')) + const eventTypes = items.map((e) => e.type) + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a']) + expect(eventTypes).toStrictEqual([ + 'beforeMultiAgentInvocationEvent', + 'beforeNodeCallEvent', + ...eventTypes.filter((t) => t === 'nodeStreamUpdateEvent'), + 'nodeResultEvent', + 'afterNodeCallEvent', + 'afterMultiAgentInvocationEvent', + 'multiAgentResultEvent', + ]) + }) + + it('yields handoff events on transitions between nodes', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply')], + edges: [['a', 'b']], + }) + + const { items } = await collectGenerator(graph.stream('go')) + + const handoffEvents = items.filter((e) => e.type === 'multiAgentHandoffEvent') + expect(handoffEvents).toHaveLength(1) + + expect(handoffEvents[0]).toEqual( + expect.objectContaining({ + type: 'multiAgentHandoffEvent', + source: 'a', + targets: ['b'], + state: expect.any(MultiAgentState), + }) + ) + }) + + it('returns cancelled result when cancel is true', async () => { + const graph = new Graph({ + nodes: [makeAgent('a')], + edges: [], + }) + + graph.addHook(BeforeNodeCallEvent, (event: BeforeNodeCallEvent) => { + event.cancel = true + }) + + const { items, result } = await collectGenerator(graph.stream('go')) + + expect(result.status).toBe(Status.CANCELLED) + expect(result.results).toHaveLength(1) + expect(result.results[0]).toEqual(expect.objectContaining({ nodeId: 'a', status: Status.CANCELLED, duration: 0 })) + + const cancelEvent = items.find((e) => e.type === 'nodeCancelEvent') + expect(cancelEvent).toEqual( + expect.objectContaining({ nodeId: 'a', state: expect.any(MultiAgentState), message: 'node cancelled by hook' }) + ) + }) + + it('returns cancelled result with custom message when cancel is a string', async () => { + const graph = new Graph({ + nodes: [makeAgent('a')], + edges: [], + }) + + graph.addHook(BeforeNodeCallEvent, (event: BeforeNodeCallEvent) => { + event.cancel = 'node not ready' + }) + + const { items, result } = await collectGenerator(graph.stream('go')) + + expect(result.status).toBe(Status.CANCELLED) + + const cancelEvent = items.find((e) => e.type === 'nodeCancelEvent') + expect(cancelEvent).toEqual( + expect.objectContaining({ nodeId: 'a', state: expect.any(MultiAgentState), message: 'node not ready' }) + ) + }) + + it('cleans up running nodes when consumer breaks mid-stream', async () => { + const graph = new Graph({ + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply')], + edges: [['a', 'b']], + }) + + const gen = graph.stream('go') + const first = await gen.next() + expect(first.done).toBe(false) + + // Simulates consumer break — should not hang waiting for node streams + const result = await gen.return(undefined as never) + expect(result.done).toBe(true) + }) + }) + + describe('resume with session manager', () => { + function makeSessionManager(storage: MockSnapshotStorage): SessionManager { + return new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + } + + it('throws when sessionManager appears in both constructor arg and plugins', () => { + const sm = makeSessionManager(new MockSnapshotStorage()) + expect( + () => + new Graph({ + nodes: [makeAgent('a')], + edges: [], + sessionManager: sm, + plugins: [sm], + }) + ).toThrow('sessionManager was provided as both a constructor argument and in the plugins array') + }) + + it('resumes from the next ready node after a linear graph stops (A→B→C, A done, resumes at B)', async () => { + const storage = new MockSnapshotStorage() + + const graph1 = new Graph({ + id: 'my-graph', + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + ['a', 'b'], + ['b', 'c'], + ], + maxSteps: 1, + sessionManager: makeSessionManager(storage), + }) + + await expect(graph1.invoke('start')).rejects.toThrow('max steps reached') + + const graph2 = new Graph({ + id: 'my-graph', + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + ['a', 'b'], + ['b', 'c'], + ], + sessionManager: makeSessionManager(storage), + }) + + const result = await graph2.invoke('start') + + expect(result.status).toBe(Status.COMPLETED) + const completedIds = result.results.filter((r) => r.status === Status.COMPLETED).map((r) => r.nodeId) + expect(completedIds).toStrictEqual(['a', 'b', 'c']) + }) + + it('resumes parallel branches independently (A→B, A→C, B done, C cancelled, resumes at C)', async () => { + const storage = new MockSnapshotStorage() + + const graph1 = new Graph({ + id: 'my-graph', + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + ['a', 'b'], + ['a', 'c'], + ], + plugins: [makeSessionManager(storage)], + maxConcurrency: 1, + }) + + graph1.addHook(BeforeNodeCallEvent, (event: BeforeNodeCallEvent) => { + if (event.nodeId === 'c') event.cancel = 'simulated stop' + }) + + await graph1.invoke('start') + + const graph2 = new Graph({ + id: 'my-graph', + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + ['a', 'b'], + ['a', 'c'], + ], + plugins: [makeSessionManager(storage)], + }) + + const result = await graph2.invoke('start') + + const completedIds = result.results.filter((r) => r.status === Status.COMPLETED).map((r) => r.nodeId) + expect(completedIds).toContain('a') + expect(completedIds).toContain('b') + expect(completedIds).toContain('c') + // A and B should appear once each (not re-executed) + expect(completedIds.filter((id) => id === 'a')).toHaveLength(1) + expect(completedIds.filter((id) => id === 'b')).toHaveLength(1) + }) + + it('starts fresh when all nodes completed in the previous run', async () => { + const storage = new MockSnapshotStorage() + + const graph1 = new Graph({ + id: 'my-graph', + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply')], + edges: [['a', 'b']], + plugins: [makeSessionManager(storage)], + }) + + const result1 = await graph1.invoke('start') + expect(result1.status).toBe(Status.COMPLETED) + + const graph2 = new Graph({ + id: 'my-graph', + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply')], + edges: [['a', 'b']], + plugins: [makeSessionManager(storage)], + }) + + const result2 = await graph2.invoke('start') + + expect(result2.status).toBe(Status.COMPLETED) + // A should appear twice — once from restored state, once from fresh execution + const aCount = result2.results.filter((r) => r.nodeId === 'a' && r.status === Status.COMPLETED).length + expect(aCount).toBe(2) + }) + + it('respects conditional edges on resume', async () => { + const storage = new MockSnapshotStorage() + + // A → B (always), A → C (condition: false) + // First run: A completes, B completes, C blocked by condition + // maxSteps=2 allows A and B but graph completes normally since C is blocked + const graph1 = new Graph({ + id: 'my-graph', + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + { source: 'a', target: 'b', handler: () => true }, + { source: 'a', target: 'c', handler: () => false }, + ], + plugins: [makeSessionManager(storage)], + }) + + const result1 = await graph1.invoke('start') + expect(result1.results.map((r) => r.nodeId)).toStrictEqual(['a', 'b']) + + // Resume: C should still be blocked by the false condition + const graph2 = new Graph({ + id: 'my-graph', + nodes: [makeAgent('a', 'a-reply'), makeAgent('b', 'b-reply'), makeAgent('c', 'c-reply')], + edges: [ + { source: 'a', target: 'b', handler: () => true }, + { source: 'a', target: 'c', handler: () => false }, + ], + plugins: [makeSessionManager(storage)], + }) + + const result2 = await graph2.invoke('start') + + // C should not appear — condition still blocks it + const completedIds = result2.results.filter((r) => r.status === Status.COMPLETED).map((r) => r.nodeId) + expect(completedIds).not.toContain('c') + }) + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/graph.tracer.test.ts b/strands-ts/src/multiagent/__tests__/graph.tracer.test.ts new file mode 100644 index 0000000000..3aab24987d --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/graph.tracer.test.ts @@ -0,0 +1,283 @@ +import { describe, expect, it, vi, beforeEach, type MockInstance } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { TextBlock } from '../../types/messages.js' +import { Tracer } from '../../telemetry/tracer.js' +import { Graph } from '../graph.js' +import { BeforeNodeCallEvent } from '../events.js' +import { Status } from '../state.js' + +interface MockTracerInstance { + startAgentSpan: MockInstance + endAgentSpan: MockInstance + startAgentLoopSpan: MockInstance + endAgentLoopSpan: MockInstance + startModelInvokeSpan: MockInstance + endModelInvokeSpan: MockInstance + startToolCallSpan: MockInstance + endToolCallSpan: MockInstance + startMultiAgentSpan: MockInstance + endMultiAgentSpan: MockInstance + startNodeSpan: MockInstance + endNodeSpan: MockInstance + withSpanContext: MockInstance +} + +vi.mock('../../telemetry/tracer.js', () => ({ + Tracer: vi.fn(function () { + return { + startAgentSpan: vi.fn().mockReturnValue({ mock: 'agentSpan' }), + endAgentSpan: vi.fn(), + startAgentLoopSpan: vi.fn().mockReturnValue({ mock: 'loopSpan' }), + endAgentLoopSpan: vi.fn(), + startModelInvokeSpan: vi.fn().mockReturnValue({ mock: 'modelSpan' }), + endModelInvokeSpan: vi.fn(), + startToolCallSpan: vi.fn().mockReturnValue({ mock: 'toolSpan' }), + endToolCallSpan: vi.fn(), + startMultiAgentSpan: vi.fn().mockReturnValue({ mock: 'multiAgentSpan' }), + endMultiAgentSpan: vi.fn(), + startNodeSpan: vi.fn().mockReturnValue({ mock: 'nodeSpan' }), + endNodeSpan: vi.fn(), + withSpanContext: vi.fn((_span: unknown, fn: () => unknown) => fn()), + } + }), +})) + +/** + * Returns the Tracer mock instance owned by the Graph. + * Agents are constructed before the Graph, so the Graph's Tracer + * is always the last one created during Graph construction. + */ +function getGraphTracer(): MockTracerInstance { + return vi.mocked(Tracer).mock.results.at(-1)!.value +} + +function makeAgent(id: string, text = 'reply'): Agent { + const model = new MockMessageModel().addTurn(new TextBlock(text)) + return new Agent({ model, printer: false, id }) +} + +function makeAgentWithUsage(id: string, text = 'reply'): Agent { + const model = new MockMessageModel().addTurn(new TextBlock(text), { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }) + return new Agent({ model, printer: false, id }) +} + +describe('Graph tracer integration', () => { + let graph: Graph + let tracer: MockTracerInstance + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('multi-agent span lifecycle', () => { + it('starts and ends multi-agent span on successful invocation', async () => { + graph = new Graph({ id: 'test-graph', nodes: [makeAgent('a')], edges: [] }) + tracer = getGraphTracer() + + await graph.invoke('Hello') + + expect(tracer.startMultiAgentSpan.mock.calls).toEqual([ + [{ orchestratorId: 'test-graph', orchestratorType: 'graph', input: 'Hello' }], + ]) + expect(tracer.endMultiAgentSpan.mock.calls.length).toBe(1) + + const [span, endOpts] = tracer.endMultiAgentSpan.mock.calls[0]! + expect(span).toStrictEqual({ mock: 'multiAgentSpan' }) + expect(endOpts).toEqual({ + duration: expect.any(Number), + usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + }) + expect(endOpts.duration).toBeGreaterThanOrEqual(0) + }) + + it('passes exact usage from result to endMultiAgentSpan', async () => { + graph = new Graph({ id: 'test-graph', nodes: [makeAgentWithUsage('a')], edges: [] }) + tracer = getGraphTracer() + + await graph.invoke('Hello') + + const [, endOpts] = tracer.endMultiAgentSpan.mock.calls[0]! + expect(endOpts.usage).toStrictEqual({ inputTokens: 10, outputTokens: 5, totalTokens: 15 }) + }) + + it('ends multi-agent span with error when maxSteps exceeded', async () => { + graph = new Graph({ + nodes: [makeAgent('a'), makeAgent('b')], + edges: [['a', 'b']], + maxSteps: 1, + }) + tracer = getGraphTracer() + + await expect(graph.invoke('Hello')).rejects.toThrow('max steps reached') + + const [span, endOpts] = tracer.endMultiAgentSpan.mock.calls[0]! + expect(span).toStrictEqual({ mock: 'multiAgentSpan' }) + expect(endOpts).toEqual({ + duration: expect.any(Number), + error: expect.objectContaining({ + message: expect.stringContaining('max steps reached'), + }), + }) + expect(endOpts.duration).toBeGreaterThanOrEqual(0) + }) + }) + + describe('node span lifecycle', () => { + it('starts and ends node span for each node execution', async () => { + graph = new Graph({ nodes: [makeAgent('a'), makeAgent('b')], edges: [['a', 'b']] }) + tracer = getGraphTracer() + + await graph.invoke('Hello') + + expect(tracer.startNodeSpan.mock.calls).toEqual([ + [{ nodeId: 'a', nodeType: 'agentNode' }], + [{ nodeId: 'b', nodeType: 'agentNode' }], + ]) + expect(tracer.endNodeSpan.mock.calls.length).toBe(2) + }) + + it('ends node span with COMPLETED status, duration, and zero usage on success', async () => { + graph = new Graph({ nodes: [makeAgent('a')], edges: [] }) + tracer = getGraphTracer() + + await graph.invoke('Hello') + + const [span, endOpts] = tracer.endNodeSpan.mock.calls[0]! + expect(span).toStrictEqual({ mock: 'nodeSpan' }) + expect(endOpts).toEqual({ + status: Status.COMPLETED, + duration: expect.any(Number), + usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + }) + expect(endOpts.duration).toBeGreaterThanOrEqual(0) + }) + + it('passes exact usage from node result to endNodeSpan', async () => { + graph = new Graph({ nodes: [makeAgentWithUsage('a')], edges: [] }) + tracer = getGraphTracer() + + await graph.invoke('Hello') + + const [, endOpts] = tracer.endNodeSpan.mock.calls[0]! + expect(endOpts.status).toBe(Status.COMPLETED) + expect(endOpts.usage).toStrictEqual({ inputTokens: 10, outputTokens: 5, totalTokens: 15 }) + }) + + it('ends node span with FAILED status when node agent throws', async () => { + const model = new MockMessageModel().addTurn(new Error('agent exploded')) + graph = new Graph({ nodes: [new Agent({ model, printer: false, id: 'a' })], edges: [] }) + tracer = getGraphTracer() + + const result = await graph.invoke('Hello') + + expect(result.status).toBe(Status.FAILED) + const [span, endOpts] = tracer.endNodeSpan.mock.calls[0]! + expect(span).toStrictEqual({ mock: 'nodeSpan' }) + expect(endOpts).toEqual({ + status: Status.FAILED, + duration: expect.any(Number), + }) + expect(endOpts.duration).toBeGreaterThanOrEqual(0) + }) + + it('ends node span with CANCELLED status and zero duration when cancelled by hook', async () => { + graph = new Graph({ nodes: [makeAgent('a')], edges: [] }) + tracer = getGraphTracer() + graph.addHook(BeforeNodeCallEvent, (event) => { + event.cancel = 'cancelled by test' + }) + + await graph.invoke('Hello') + + expect(tracer.endNodeSpan.mock.calls).toEqual([[{ mock: 'nodeSpan' }, { status: Status.CANCELLED, duration: 0 }]]) + }) + + it('ends node span with INTERRUPTED status when a hook raises an interrupt', async () => { + graph = new Graph({ nodes: [makeAgent('a')], edges: [] }) + tracer = getGraphTracer() + graph.addHook(BeforeNodeCallEvent, (event) => { + event.interrupt({ name: 'gate', reason: 'approve?' }) + }) + + const result = await graph.invoke('Hello') + + expect(result.status).toBe(Status.INTERRUPTED) + expect(tracer.endNodeSpan).toHaveBeenCalledTimes(1) + const [span, endArgs] = tracer.endNodeSpan.mock.calls[0]! + expect(span).toEqual({ mock: 'nodeSpan' }) + expect(endArgs.status).toBe(Status.INTERRUPTED) + expect(typeof endArgs.duration).toBe('number') + }) + }) + + describe('null span handling', () => { + it('completes successfully when startMultiAgentSpan returns null', async () => { + graph = new Graph({ nodes: [makeAgent('a')], edges: [] }) + tracer = getGraphTracer() + tracer.startMultiAgentSpan.mockReturnValue(null) + + const result = await graph.invoke('Hello') + + expect(result.status).toBe(Status.COMPLETED) + const [span] = tracer.endMultiAgentSpan.mock.calls[0]! + expect(span).toBeNull() + }) + + it('completes successfully when startNodeSpan returns null', async () => { + graph = new Graph({ nodes: [makeAgent('a')], edges: [] }) + tracer = getGraphTracer() + tracer.startNodeSpan.mockReturnValue(null) + + const result = await graph.invoke('Hello') + + expect(result.status).toBe(Status.COMPLETED) + const [span] = tracer.endNodeSpan.mock.calls[0]! + expect(span).toBeNull() + }) + }) + + describe('span context propagation', () => { + it('passes node span to every withSpanContext call during node execution', async () => { + graph = new Graph({ nodes: [makeAgent('a')], edges: [] }) + tracer = getGraphTracer() + + await graph.invoke('Hello') + + // First call: multiAgentSpan to create nodeSpan, then nodeSpan for node.stream() + gen.next() calls + const calls = tracer.withSpanContext.mock.calls + expect(calls.length).toBeGreaterThanOrEqual(3) + + // First call uses multiAgentSpan to create the nodeSpan + expect(calls[0]).toEqual([{ mock: 'multiAgentSpan' }, expect.any(Function)]) + + // Subsequent calls use nodeSpan for node execution + const subsequentCalls = calls.slice(1) + expect(subsequentCalls).toEqual( + expect.arrayContaining(Array(subsequentCalls.length).fill([{ mock: 'nodeSpan' }, expect.any(Function)])) + ) + }) + }) + + describe('parallel node execution', () => { + it('creates separate node spans for parallel source nodes', async () => { + graph = new Graph({ + nodes: [makeAgent('a'), makeAgent('b'), makeAgent('c')], + edges: [ + ['a', 'c'], + ['b', 'c'], + ], + }) + tracer = getGraphTracer() + + await graph.invoke('Hello') + + const nodeIds = tracer.startNodeSpan.mock.calls.map((call) => call[0].nodeId) + expect(nodeIds).toEqual(expect.arrayContaining(['a', 'b', 'c'])) + expect(tracer.startNodeSpan.mock.calls.length).toBe(3) + expect(tracer.endNodeSpan.mock.calls.length).toBe(3) + }) + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/interrupts.test.ts b/strands-ts/src/multiagent/__tests__/interrupts.test.ts new file mode 100644 index 0000000000..f741a71b9c --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/interrupts.test.ts @@ -0,0 +1,441 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { MockSnapshotStorage } from '../../__fixtures__/mock-storage-provider.js' +import { createCancellableAgent } from '../../__fixtures__/agent-helpers.js' +import { createMockTool } from '../../__fixtures__/tool-helpers.js' +import { InterruptResponseContent } from '../../types/interrupt.js' +import { Graph } from '../graph.js' +import { Swarm } from '../swarm.js' +import { Status } from '../state.js' +import { SessionManager } from '../../session/session-manager.js' +import { BeforeNodeCallEvent } from '../events.js' +import { TextBlock } from '../../types/messages.js' + +/** + * Interrupt round-trip tests. Verifies that an orchestrator can hit an interrupt, + * persist enough state via a SessionManager to let a later invocation resume, and + * produce a clean terminal result once all interrupts are answered. + * + * Each run uses a fresh agent instance so session-driven state restoration is what + * wires resume together — just like a real cross-process resume. + */ + +function makeSessionManager(storage: MockSnapshotStorage): SessionManager { + return new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) +} + +/** Tool that interrupts once, then returns a static value on resume. */ +function interruptingTool(name: string, interruptName: string, resumeValue = 'ok') { + return createMockTool(name, (context) => { + context.interrupt({ name: interruptName, reason: `need ${interruptName}` }) + return resumeValue + }) +} + +describe('Multi-agent interrupts: round-trip', () => { + it('Graph: agent interrupts, resumes via top-level SessionManager', async () => { + const storage = new MockSnapshotStorage() + const tool = interruptingTool('confirmTool', 'confirm', 'approved') + + // Agent's model for run 1 returns a tool use (which interrupts). + const modelRun1 = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-1', + input: {}, + }) + const agent1 = new Agent({ model: modelRun1, tools: [tool], printer: false, id: 'a' }) + const graph1 = new Graph({ + nodes: [agent1], + edges: [], + sessionManager: makeSessionManager(storage), + }) + + const interruptResult = await graph1.invoke('go') + expect(interruptResult.status).toBe(Status.INTERRUPTED) + expect(interruptResult.interrupts).toHaveLength(1) + + // Run 2's model provides the final text turn plus a trailing turn that should + // never be consumed. Two turns are needed so the mock model's callCount tracks + // (single-turn mode has a quirk where callCount stays at 0 regardless of calls). + // If the resumed agent replayed the pending tool use correctly, it calls the + // model exactly once (for the post-tool turn) — NOT twice (which would mean the + // tool use was re-fetched from the model instead of replayed). + const modelRun2 = new MockMessageModel() + .addTurn({ type: 'textBlock', text: 'done' }) + .addTurn({ type: 'textBlock', text: 'unreachable' }) + const agent2 = new Agent({ model: modelRun2, tools: [tool], printer: false, id: 'a' }) + const graph2 = new Graph({ + nodes: [agent2], + edges: [], + sessionManager: makeSessionManager(storage), + }) + + const response = new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'yes', + }) + const finalResult = await graph2.invoke([response]) + + expect(finalResult.status).toBe(Status.COMPLETED) + expect(finalResult.interrupts).toBeUndefined() + for (const result of finalResult.results) { + expect(result.interrupts).toBeUndefined() + } + // Model called exactly once on resume — for the post-tool turn. The pending + // tool use came from the restored snapshot, not a re-fetch. + expect(modelRun2.callCount).toBe(1) + }) + + it('Swarm: agent interrupts, resumes via top-level SessionManager', async () => { + const storage = new MockSnapshotStorage() + const tool = interruptingTool('confirmTool', 'confirm_a', 'resumed') + + const modelRun1 = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-A', + input: {}, + }) + const agent1 = new Agent({ model: modelRun1, tools: [tool], printer: false, id: 'a' }) + const swarm1 = new Swarm({ + nodes: [agent1], + start: 'a', + sessionManager: makeSessionManager(storage), + }) + const interruptResult = await swarm1.invoke('start') + expect(interruptResult.status).toBe(Status.INTERRUPTED) + expect(interruptResult.interrupts).toHaveLength(1) + + // Swarm uses structured output for handoffs — the final (non-handoff) turn + // terminates execution. + const modelRun2 = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'so-1', + input: { message: 'all done' }, + }) + const agent2 = new Agent({ model: modelRun2, tools: [tool], printer: false, id: 'a' }) + const swarm2 = new Swarm({ + nodes: [agent2], + start: 'a', + sessionManager: makeSessionManager(storage), + }) + const response = new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'ok', + }) + const finalResult = await swarm2.invoke([response]) + + expect(finalResult.status).toBe(Status.COMPLETED) + }) + + it('Graph parallel: interrupt on one branch lets in-flight sibling finish', async () => { + const tool = interruptingTool('confirmTool', 'confirm', 'approved') + + // Source node 'start' runs quickly and produces two parallel branches. + // Branch 'interrupter' interrupts immediately. Branch 'sibling' takes a moment + // to complete. The interrupt does not abort siblings — they run to completion + // and the aggregate result carries both outcomes. + const startModel = new MockMessageModel().addTurn({ type: 'textBlock', text: 'go' }) + const start = new Agent({ model: startModel, printer: false, id: 'start' }) + + const interrupterModel = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-i', + input: {}, + }) + const interrupter = new Agent({ model: interrupterModel, tools: [tool], printer: false, id: 'interrupter' }) + + const sibling = createCancellableAgent('sibling', 50) + + const graph = new Graph({ + nodes: [start, interrupter, sibling], + edges: [ + ['start', 'interrupter'], + ['start', 'sibling'], + ], + timeout: 5_000, + }) + + const result = await graph.invoke('begin') + + // Aggregate status surfaces INTERRUPTED (the actionable state) — `_resolveStatus` + // ranks INTERRUPTED above COMPLETED. + expect(result.status).toBe(Status.INTERRUPTED) + + const siblingResult = result.results.find((r) => r.nodeId === 'sibling') + expect(siblingResult?.status).toBe(Status.COMPLETED) + + const interrupterResult = result.results.find((r) => r.nodeId === 'interrupter') + expect(interrupterResult?.status).toBe(Status.INTERRUPTED) + expect(interrupterResult?.interrupts).toHaveLength(1) + }) + + it('Nested orchestrator: interrupts bubble up on first run but do not round-trip without a nested SessionManager', async () => { + // Nested orchestrator has no SessionManager of its own, only the outer one does. + // First run works (interrupt bubbles up through MultiAgentNode into outer result). + // Second run FAILS at routing because the nested state was never persisted: the + // nested Swarm's NodeState.interrupts is empty on rehydrate, so the response id + // has no home. This test pins down the documented limitation. + const storage = new MockSnapshotStorage() + const tool = interruptingTool('confirmTool', 'confirm_nested', 'ok') + + const buildInner = (): Swarm => { + const model = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-n', + input: {}, + }) + const agent = new Agent({ model, tools: [tool], printer: false, id: 'inner-agent' }) + return new Swarm({ nodes: [agent], start: 'inner-agent', id: 'inner' }) + } + + const outer1 = new Graph({ + nodes: [{ orchestrator: buildInner() }], + edges: [], + sessionManager: makeSessionManager(storage), + }) + const interruptResult = await outer1.invoke('go') + expect(interruptResult.status).toBe(Status.INTERRUPTED) + expect(interruptResult.interrupts).toHaveLength(1) + + const outer2 = new Graph({ + nodes: [{ orchestrator: buildInner() }], + edges: [], + sessionManager: makeSessionManager(storage), + }) + const response = new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'yes', + }) + + // Routing at the outer level finds the MultiAgentNode. Inside, the nested + // Swarm creates a fresh MultiAgentState; no nested NodeState.interrupts match + // the response id, so groupInterruptResponsesByNode throws. `Node.stream` catches + // the error and produces a FAILED result for the nested node. The limitation is + // diagnosable via the error message on that node's result, just not transparent. + const finalResult = await outer2.invoke([response]) + const innerNode = finalResult.results.find((r) => r.nodeId === 'inner') + expect(innerNode?.status).toBe(Status.FAILED) + expect(innerNode?.error?.message).toMatch(/no node found with matching interrupt/) + }) + + it('Graph: BeforeNodeCallEvent.interrupt gates a node before it runs, resumes via SessionManager', async () => { + const storage = new MockSnapshotStorage() + + // The gated node has a normal agent — the interrupt fires BEFORE the node runs + // via an orchestrator hook, not from inside the agent. + const buildAgent = () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'executed' }) + return new Agent({ model, printer: false, id: 'execute' }) + } + + const graph1 = new Graph({ + nodes: [buildAgent()], + edges: [], + sessionManager: makeSessionManager(storage), + }) + graph1.addHook(BeforeNodeCallEvent, (event) => { + if (event.nodeId === 'execute') { + event.interrupt({ name: 'node_approval', reason: 'approve?' }) + } + }) + + const interruptResult = await graph1.invoke('begin') + expect(interruptResult.status).toBe(Status.INTERRUPTED) + expect(interruptResult.interrupts).toHaveLength(1) + expect(interruptResult.interrupts![0]!.source).toBe('multiagent-hook') + + // Resume with approval. Hook runs again, sees the stored response, returns it + // without throwing. Node proceeds to execute. + const graph2 = new Graph({ + nodes: [buildAgent()], + edges: [], + sessionManager: makeSessionManager(storage), + }) + graph2.addHook(BeforeNodeCallEvent, (event) => { + if (event.nodeId === 'execute') { + const response = event.interrupt<{ approved: boolean }>({ name: 'node_approval', reason: 'approve?' }) + if (!response.approved) { + event.cancel = 'not approved' + } + } + }) + + const response = new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: { approved: true }, + }) + const finalResult = await graph2.invoke([response]) + + expect(finalResult.status).toBe(Status.COMPLETED) + const executedNode = finalResult.results.find((r) => r.nodeId === 'execute') + expect(executedNode?.status).toBe(Status.COMPLETED) + expect(executedNode?.content.some((b) => b instanceof TextBlock && b.text === 'executed')).toBe(true) + }) + + it('Swarm: BeforeNodeCallEvent.interrupt gates a node before it runs, resumes via SessionManager', async () => { + const storage = new MockSnapshotStorage() + const buildAgent = () => { + const model = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'so-1', + input: { message: 'ran' }, + }) + return new Agent({ model, printer: false, id: 'a' }) + } + + const swarm1 = new Swarm({ + nodes: [buildAgent()], + start: 'a', + sessionManager: makeSessionManager(storage), + }) + swarm1.addHook(BeforeNodeCallEvent, (event) => { + event.interrupt({ name: 'gate', reason: 'approve?' }) + }) + + const interruptResult = await swarm1.invoke('begin') + expect(interruptResult.status).toBe(Status.INTERRUPTED) + expect(interruptResult.interrupts![0]!.source).toBe('multiagent-hook') + + const swarm2 = new Swarm({ + nodes: [buildAgent()], + start: 'a', + sessionManager: makeSessionManager(storage), + }) + swarm2.addHook(BeforeNodeCallEvent, (event) => { + event.interrupt({ name: 'gate', reason: 'approve?' }) + }) + + const response = new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'approved', + }) + const finalResult = await swarm2.invoke([response]) + expect(finalResult.status).toBe(Status.COMPLETED) + }) + + it('Graph: hook gate + tool interrupt across successive runs, each layer resumed in turn', async () => { + // Run 1: orchestrator hook gates the node (hook interrupt, source=multiagent-hook). + // Run 2: hook approves on resume, node runs, tool interrupts (source=tool). + // Run 3: tool resumes, agent completes. + // Exercises both interrupt layers in the same graph, with proper layer routing + // via applyOrchestratorHookResponses. + const storage = new MockSnapshotStorage() + const tool = interruptingTool('toolInterrupt', 'tool_confirm', 'done') + + // Each run uses a fresh agent whose model provides only the turns it needs. + const buildGraph = (modelTurns: 'toolUse' | 'text'): Graph => { + const model = new MockMessageModel() + if (modelTurns === 'toolUse') { + model.addTurn({ type: 'toolUseBlock', name: 'toolInterrupt', toolUseId: 'tool-1', input: {} }) + } else { + model.addTurn({ type: 'textBlock', text: 'done' }).addTurn({ type: 'textBlock', text: 'unreachable' }) + } + const agent = new Agent({ model, tools: [tool], printer: false, id: 'a' }) + const graph = new Graph({ + nodes: [agent], + edges: [], + sessionManager: makeSessionManager(storage), + }) + graph.addHook(BeforeNodeCallEvent, (event) => { + event.interrupt({ name: 'hook_gate', reason: 'approve node?' }) + }) + return graph + } + + const run1 = await buildGraph('toolUse').invoke('begin') + expect(run1.status).toBe(Status.INTERRUPTED) + expect(run1.interrupts![0]!.source).toBe('multiagent-hook') + + const hookResponse = new InterruptResponseContent({ + interruptId: run1.interrupts![0]!.id, + response: { approved: true }, + }) + const run2 = await buildGraph('toolUse').invoke([hookResponse]) + expect(run2.status).toBe(Status.INTERRUPTED) + expect(run2.interrupts![0]!.source).toBe('tool') + }) + + it('Graph: hook-gated node still emits NodeResultEvent and AfterNodeCallEvent', async () => { + // Lifecycle observers (SessionManager per-node save, metrics, tracing) rely on + // each node terminating with the same event pair regardless of HOW it terminated. + const agent = new Agent({ model: new MockMessageModel(), printer: false, id: 'gated' }) + const graph = new Graph({ nodes: [agent], edges: [] }) + graph.addHook(BeforeNodeCallEvent, (event) => { + event.interrupt({ name: 'gate', reason: 'approve?' }) + }) + + const eventTypes: string[] = [] + for await (const event of graph.stream('hi')) { + eventTypes.push(event.type) + } + + expect(eventTypes).toContain('beforeNodeCallEvent') + expect(eventTypes).toContain('nodeResultEvent') + expect(eventTypes).toContain('afterNodeCallEvent') + // Strict ordering: after comes after result, which comes after before. + expect(eventTypes.indexOf('beforeNodeCallEvent')).toBeLessThan(eventTypes.indexOf('nodeResultEvent')) + expect(eventTypes.indexOf('nodeResultEvent')).toBeLessThan(eventTypes.indexOf('afterNodeCallEvent')) + }) + + it('Swarm: hook-gated node still emits NodeResultEvent and AfterNodeCallEvent', async () => { + const agent = new Agent({ model: new MockMessageModel(), printer: false, id: 'a' }) + const swarm = new Swarm({ nodes: [agent], start: 'a' }) + swarm.addHook(BeforeNodeCallEvent, (event) => { + event.interrupt({ name: 'gate', reason: 'approve?' }) + }) + + const eventTypes: string[] = [] + for await (const event of swarm.stream('hi')) { + eventTypes.push(event.type) + } + + expect(eventTypes).toContain('beforeNodeCallEvent') + expect(eventTypes).toContain('nodeResultEvent') + expect(eventTypes).toContain('afterNodeCallEvent') + expect(eventTypes.indexOf('beforeNodeCallEvent')).toBeLessThan(eventTypes.indexOf('nodeResultEvent')) + expect(eventTypes.indexOf('nodeResultEvent')).toBeLessThan(eventTypes.indexOf('afterNodeCallEvent')) + }) + + it('Graph: resume against a graph whose topology changed throws a descriptive error', async () => { + // Simulate a save/restore where the reconstructed graph is missing a node that + // had an outstanding interrupt in the saved state. The routing lookup should fail + // loudly rather than silently (which would previously have crashed on a non-null + // assertion with an unhelpful TypeError). + const storage = new MockSnapshotStorage() + const tool = interruptingTool('confirmTool', 'confirm_top', 'ok') + + const model1 = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'confirmTool', + toolUseId: 'tool-topo', + input: {}, + }) + const agent1 = new Agent({ model: model1, tools: [tool], printer: false, id: 'will-vanish' }) + const graph1 = new Graph({ nodes: [agent1], edges: [], sessionManager: makeSessionManager(storage) }) + const interruptResult = await graph1.invoke('go') + expect(interruptResult.status).toBe(Status.INTERRUPTED) + + const differentAgent = new Agent({ + model: new MockMessageModel(), + printer: false, + id: 'different-node', + }) + const graph2 = new Graph({ nodes: [differentAgent], edges: [], sessionManager: makeSessionManager(storage) }) + const response = new InterruptResponseContent({ + interruptId: interruptResult.interrupts![0]!.id, + response: 'yes', + }) + + await expect(graph2.invoke([response])).rejects.toThrow(/topology changed between save and resume/) + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/nodes.test.ts b/strands-ts/src/multiagent/__tests__/nodes.test.ts new file mode 100644 index 0000000000..e094109d71 --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/nodes.test.ts @@ -0,0 +1,405 @@ +import { beforeEach, describe, expect, it } from 'vitest' +import { z } from 'zod' +import { Agent } from '../../agent/agent.js' +import { BeforeInvocationEvent } from '../../hooks/events.js' +import type { MultiAgentInput } from '../multiagent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import { TextBlock } from '../../types/messages.js' +import { MultiAgentResult, MultiAgentState, NodeResult, Status } from '../state.js' +import type { MultiAgentStreamEvent } from '../events.js' +import { MultiAgentHandoffEvent, NodeStreamUpdateEvent } from '../events.js' +import { AgentNode, MultiAgentNode, Node } from '../nodes.js' +import type { MultiAgent } from '../multiagent.js' +import type { NodeResultUpdate } from '../state.js' + +/** + * Concrete Node subclass for testing the abstract base class. + */ +class TestNode extends Node { + private readonly _fn: ( + args: MultiAgentInput, + state: MultiAgentState + ) => AsyncGenerator + + constructor( + id: string, + fn: ( + args: MultiAgentInput, + state: MultiAgentState + ) => AsyncGenerator + ) { + super(id, {}) + this._fn = fn + } + + async *handle( + args: MultiAgentInput, + state: MultiAgentState + ): AsyncGenerator { + return yield* this._fn(args, state) + } +} + +describe('Node', () => { + let state: MultiAgentState + + beforeEach(() => { + state = new MultiAgentState({ nodeIds: ['test-node', 'fail-node'] }) + }) + + describe('stream', () => { + it('returns COMPLETED NodeResult on successful execution', async () => { + const content = [new TextBlock('result')] + const node = new TestNode('test-node', async function* () { + yield* [] + return { content } + }) + + const { items, result } = await collectGenerator(node.stream([], state)) + + const resultEvent = items.find((e) => e.type === 'nodeResultEvent') + expect(resultEvent).toEqual({ + type: 'nodeResultEvent', + nodeId: 'test-node', + nodeType: 'node', + state, + result, + invocationState: {}, + }) + + expect(result).toEqual({ + type: 'nodeResult', + nodeId: 'test-node', + status: Status.COMPLETED, + content, + duration: expect.any(Number), + }) + }) + + it('catches errors and returns FAILED NodeResult', async () => { + const node = new TestNode('fail-node', async function* () { + yield* [] + throw new Error('boom') + }) + + const { items, result } = await collectGenerator(node.stream([], state)) + + const resultEvent = items.find((e) => e.type === 'nodeResultEvent') + expect(resultEvent).toEqual({ + type: 'nodeResultEvent', + nodeId: 'fail-node', + nodeType: 'node', + state, + result, + invocationState: {}, + }) + + expect(result).toEqual({ + type: 'nodeResult', + nodeId: 'fail-node', + status: Status.FAILED, + content: [], + duration: expect.any(Number), + error: expect.objectContaining({ message: 'boom' }), + }) + }) + }) +}) + +describe('AgentNode', () => { + let agent: Agent + let node: AgentNode + let state: MultiAgentState + + beforeEach(() => { + const model = new MockMessageModel().addTurn(new TextBlock('reply')) + agent = new Agent({ model, printer: false, appState: { key1: 'value1' }, id: 'agent-1' }) + node = new AgentNode({ agent }) + state = new MultiAgentState({ nodeIds: ['agent-1'] }) + }) + + describe('constructor', () => { + it('throws when timeout < 1', () => { + expect(() => new AgentNode({ agent, timeout: 0 })).toThrow('timeout=<0>, node_id= | must be at least 1') + }) + + it('accepts a positive timeout', () => { + const timedNode = new AgentNode({ agent, timeout: 5_000 }) + expect(timedNode.timeout).toBe(5_000) + }) + + it('accepts Infinity as an explicit opt-out', () => { + const timedNode = new AgentNode({ agent, timeout: Infinity }) + expect(timedNode.timeout).toBe(Infinity) + }) + + it('defaults preserveContext to false', () => { + expect(node.preserveContext).toBe(false) + }) + + it('stores the preserveContext flag when provided', () => { + const preserveContextNode = new AgentNode({ agent, preserveContext: true }) + expect(preserveContextNode.preserveContext).toBe(true) + }) + + it('throws when preserveContext is set with a non-Agent InvokableAgent', () => { + const customAgent = { + id: 'custom', + async invoke() { + throw new Error('not used') + }, + // eslint-disable-next-line require-yield + async *stream() { + throw new Error('not used') + }, + addHook() { + return () => {} + }, + } + expect(() => new AgentNode({ agent: customAgent, preserveContext: true })).toThrow( + /preserveContext=true requires an Agent/ + ) + }) + }) + + describe('handle', () => { + it('wraps agent events and returns content', async () => { + const { items, result } = await collectGenerator(node.stream([new TextBlock('prompt')], state)) + + const streamEvents = items.filter((e) => e.type === 'nodeStreamUpdateEvent') + expect(streamEvents.length).toBeGreaterThan(0) + for (const event of streamEvents) { + expect(event).toEqual( + expect.objectContaining({ type: 'nodeStreamUpdateEvent', nodeId: 'agent-1', nodeType: 'agentNode' }) + ) + } + + const resultEvent = items.find((e) => e.type === 'nodeResultEvent') + expect(resultEvent).toEqual( + expect.objectContaining({ type: 'nodeResultEvent', nodeId: 'agent-1', nodeType: 'agentNode', result }) + ) + + expect(result).toEqual({ + type: 'nodeResult', + nodeId: 'agent-1', + status: Status.COMPLETED, + content: expect.arrayContaining([expect.objectContaining({ type: 'textBlock', text: 'reply' })]), + duration: expect.any(Number), + usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + }) + }) + + it('restores agent messages and state after execution', async () => { + const messagesBefore = [...agent.messages] + const stateBefore = agent.appState.getAll() + + await collectGenerator(node.stream([new TextBlock('prompt')], state)) + + expect(agent.messages).toStrictEqual(messagesBefore) + expect(agent.appState.getAll()).toStrictEqual(stateBefore) + }) + + it('retains agent messages across executions when preserveContext is true', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('reply-1')).addTurn(new TextBlock('reply-2')) + const preserveContextAgent = new Agent({ model, printer: false, id: 'preserve-context-agent' }) + const preserveContextNode = new AgentNode({ agent: preserveContextAgent, preserveContext: true }) + const preserveContextState = new MultiAgentState({ nodeIds: ['preserve-context-agent'] }) + + await collectGenerator(preserveContextNode.stream([new TextBlock('first')], preserveContextState)) + const messagesAfterFirst = preserveContextAgent.messages.length + expect(messagesAfterFirst).toBeGreaterThan(0) + + await collectGenerator(preserveContextNode.stream([new TextBlock('second')], preserveContextState)) + + expect(preserveContextAgent.messages.length).toBeGreaterThan(messagesAfterFirst) + }) + + it('retains appState mutations across executions when preserveContext is true', async () => { + const model = new MockMessageModel().addTurn(new TextBlock('reply-1')).addTurn(new TextBlock('reply-2')) + const preserveContextAgent = new Agent({ model, printer: false, id: 'preserve-context-agent' }) + // Hook bumps a counter on appState every time the agent is invoked. + preserveContextAgent.addHook(BeforeInvocationEvent, (event) => { + const count = event.agent.appState.get<{ count: number }>('count') ?? 0 + event.agent.appState.set('count', count + 1) + }) + const preserveContextNode = new AgentNode({ agent: preserveContextAgent, preserveContext: true }) + const preserveContextState = new MultiAgentState({ nodeIds: ['preserve-context-agent'] }) + + await collectGenerator(preserveContextNode.stream([new TextBlock('first')], preserveContextState)) + expect(preserveContextAgent.appState.get<{ count: number }>('count')).toBe(1) + + await collectGenerator(preserveContextNode.stream([new TextBlock('second')], preserveContextState)) + expect(preserveContextAgent.appState.get<{ count: number }>('count')).toBe(2) + }) + + it('passes structuredOutputSchema from options to the agent', async () => { + const schema = z.object({ agentName: z.string().optional(), message: z.string() }) + + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: { message: 'hello' }, + }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + agent = new Agent({ model, printer: false, id: 'schema-agent' }) + node = new AgentNode({ agent }) + state = new MultiAgentState({ nodeIds: ['schema-agent'] }) + + const { result } = await collectGenerator(node.stream('test', state, { structuredOutputSchema: schema })) + + expect(result.structuredOutput).toStrictEqual({ message: 'hello' }) + }) + }) + + describe('agent', () => { + it('exposes the wrapped agent instance', () => { + expect(node.agent).toBe(agent) + }) + }) +}) + +describe('MultiAgentNode', () => { + const content = [new TextBlock('inner-result')] + + /** + * Creates a mock orchestrator that yields the given events and returns a result with the given content. + */ + function mockOrchestrator(id: string, events: MultiAgentStreamEvent[]): MultiAgent { + return { + id, + invoke: async () => new MultiAgentResult({ results: [], duration: 0 }), + async *stream() { + for (const event of events) { + yield event + } + return new MultiAgentResult({ + results: [new NodeResult({ nodeId: id, status: Status.COMPLETED, duration: 0, content })], + content, + duration: 0, + }) + }, + addHook: () => () => {}, + } + } + + let node: MultiAgentNode + let state: MultiAgentState + + beforeEach(() => { + const orchestrator = mockOrchestrator('inner', []) + node = new MultiAgentNode({ orchestrator }) + state = new MultiAgentState({ nodeIds: ['inner'] }) + }) + + describe('constructor', () => { + it('derives id from orchestrator', () => { + expect(node.id).toBe('inner') + }) + }) + + describe('handle', () => { + it('passes through inner NodeStreamUpdateEvents', async () => { + const innerUpdate = new MultiAgentHandoffEvent({ source: 'x', targets: ['y'], state, invocationState: {} }) + const innerEvent = new NodeStreamUpdateEvent({ + nodeId: 'deep-node', + nodeType: 'agentNode', + state, + inner: { source: 'multiAgent', event: innerUpdate }, + invocationState: {}, + }) + const orchestrator = mockOrchestrator('inner', [innerEvent]) + node = new MultiAgentNode({ orchestrator }) + + const { items } = await collectGenerator(node.stream([], state)) + + const streamEvents = items.filter((e) => e.type === 'nodeStreamUpdateEvent') as NodeStreamUpdateEvent[] + const passthrough = streamEvents.find((e) => e.nodeId === 'deep-node') + expect(passthrough).toBe(innerEvent) + }) + + it('wraps non-NodeStreamUpdateEvents with this node identity', async () => { + const handoff = new MultiAgentHandoffEvent({ source: 'a', targets: ['b'], state, invocationState: {} }) + const orchestrator = mockOrchestrator('inner', [handoff]) + node = new MultiAgentNode({ orchestrator }) + + const { items } = await collectGenerator(node.stream([], state)) + + const streamEvents = items.filter((e) => e.type === 'nodeStreamUpdateEvent') as NodeStreamUpdateEvent[] + const wrapped = streamEvents.find((e) => e.nodeId === 'inner' && e.inner.event === handoff) + expect(wrapped).toBeDefined() + expect(wrapped!.nodeType).toBe('multiAgentNode') + }) + + it('returns orchestrator content', async () => { + const { result } = await collectGenerator(node.stream([], state)) + + expect(result).toEqual( + expect.objectContaining({ + nodeId: 'inner', + status: Status.COMPLETED, + content, + }) + ) + }) + + it('propagates FAILED status from inner orchestrator', async () => { + const failedOrchestrator: MultiAgent = { + id: 'inner', + invoke: async () => new MultiAgentResult({ results: [], duration: 0 }), + async *stream() { + yield* [] + return new MultiAgentResult({ + status: Status.FAILED, + results: [ + new NodeResult({ nodeId: 'x', status: Status.FAILED, duration: 0, error: new Error('inner boom') }), + ], + content: [], + duration: 0, + error: new Error('inner boom'), + }) + }, + addHook: () => () => {}, + } + node = new MultiAgentNode({ orchestrator: failedOrchestrator }) + + const { result } = await collectGenerator(node.stream([], state)) + + expect(result.status).toBe(Status.FAILED) + expect(result.error?.message).toBe('inner boom') + }) + + it('propagates CANCELLED status from inner orchestrator', async () => { + const cancelledOrchestrator: MultiAgent = { + id: 'inner', + invoke: async () => new MultiAgentResult({ results: [], duration: 0 }), + async *stream() { + yield* [] + return new MultiAgentResult({ + status: Status.CANCELLED, + results: [], + content: [], + duration: 0, + }) + }, + addHook: () => () => {}, + } + node = new MultiAgentNode({ orchestrator: cancelledOrchestrator }) + + const { result } = await collectGenerator(node.stream([], state)) + + expect(result.status).toBe(Status.CANCELLED) + }) + }) + + describe('orchestrator', () => { + it('exposes the wrapped orchestrator instance', () => { + const orchestrator = mockOrchestrator('test', []) + node = new MultiAgentNode({ orchestrator }) + expect(node.orchestrator).toBe(orchestrator) + }) + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/queue.test.ts b/strands-ts/src/multiagent/__tests__/queue.test.ts new file mode 100644 index 0000000000..8d2d59e24c --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/queue.test.ts @@ -0,0 +1,156 @@ +import { beforeEach, describe, expect, it } from 'vitest' +import { Queue } from '../queue.js' +import type { QueueData } from '../queue.js' +import type { Node } from '../nodes.js' +import { NodeResult, Status } from '../state.js' + +describe('Queue', () => { + let queue: Queue + let mockNode: Node + + beforeEach(() => { + mockNode = { id: 'node-1' } as Node + queue = new Queue() + }) + + describe('push and shift', () => { + it('dequeues in FIFO order', () => { + const data1: QueueData = { + type: 'result', + node: mockNode, + result: new NodeResult({ nodeId: 'node-1', status: Status.COMPLETED, duration: 10 }), + } + const data2: QueueData = { type: 'error', node: mockNode, error: new Error('fail') } + + queue.push(data1) + queue.push(data2) + + expect(queue.shift()?.data).toBe(data1) + expect(queue.shift()?.data).toBe(data2) + }) + + it('returns undefined when empty', () => { + expect(queue.shift()).toBeUndefined() + }) + + it('provides a no-op ack for fire-and-forget pushes', () => { + queue.push({ type: 'error', node: mockNode, error: new Error('a') }) + const entry = queue.shift()! + expect(() => entry.ack()).not.toThrow() + }) + }) + + describe('send', () => { + it('resolves when consumer calls ack', async () => { + const data: QueueData = { type: 'error', node: mockNode, error: new Error('a') } + let resolved = false + + const waiting = queue.send(data).then(() => { + resolved = true + }) + + await Promise.resolve() + expect(resolved).toBe(false) + + const entry = queue.shift()! + expect(entry.data).toBe(data) + + await Promise.resolve() + expect(resolved).toBe(false) + + entry.ack() + await waiting + expect(resolved).toBe(true) + }) + }) + + describe('size', () => { + it('reflects the current number of entries', () => { + expect(queue.size).toBe(0) + + queue.push({ type: 'error', node: mockNode, error: new Error('a') }) + queue.push({ type: 'error', node: mockNode, error: new Error('b') }) + expect(queue.size).toBe(2) + + queue.shift() + expect(queue.size).toBe(1) + }) + }) + + describe('wait', () => { + it('resolves immediately when entries are available', async () => { + queue.push({ type: 'error', node: mockNode, error: new Error('a') }) + + await queue.wait() + + expect(queue.size).toBe(1) + }) + + it('blocks until data is pushed', async () => { + let resolved = false + + const waiting = queue.wait().then(() => { + resolved = true + }) + + await Promise.resolve() + expect(resolved).toBe(false) + + queue.push({ type: 'error', node: mockNode, error: new Error('a') }) + + await waiting + expect(resolved).toBe(true) + }) + + it('blocks until data is sent', async () => { + let resolved = false + + const waiting = queue.wait().then(() => { + resolved = true + }) + + await Promise.resolve() + expect(resolved).toBe(false) + + const data: QueueData = { type: 'error', node: mockNode, error: new Error('a') } + // Don't await send — it won't resolve until ack + const sending = queue.send(data) + + await waiting + expect(resolved).toBe(true) + + // Clean up: ack so send resolves + queue.shift()!.ack() + await sending + }) + }) + + describe('dispose', () => { + it('resolves pending send acks and drains entries', async () => { + let resolved = false + const data: QueueData = { type: 'error', node: mockNode, error: new Error('a') } + const sending = queue.send(data).then(() => { + resolved = true + }) + + await Promise.resolve() + expect(resolved).toBe(false) + expect(queue.size).toBe(1) + + queue.dispose() + + await sending + expect(resolved).toBe(true) + expect(queue.size).toBe(0) + }) + + it('causes future send calls to resolve immediately', async () => { + queue.dispose() + + const data: QueueData = { type: 'error', node: mockNode, error: new Error('a') } + await queue.send(data) + + expect(queue.size).toBe(0) + }) + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/snapshot.test.ts b/strands-ts/src/multiagent/__tests__/snapshot.test.ts new file mode 100644 index 0000000000..11fe70cdf2 --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/snapshot.test.ts @@ -0,0 +1,198 @@ +import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { TextBlock } from '../../types/messages.js' +import { SNAPSHOT_SCHEMA_VERSION } from '../../types/snapshot.js' +import type { Snapshot } from '../../types/snapshot.js' +import { takeSnapshot, loadSnapshot } from '../snapshot.js' +import { Graph } from '../graph.js' +import { Swarm } from '../swarm.js' +import { MultiAgentState, NodeResult, Status } from '../state.js' + +const MOCK_TIMESTAMP = '2026-01-15T12:00:00.000Z' + +function makeAgent(id: string, text = 'reply'): Agent { + const model = new MockMessageModel().addTurn(new TextBlock(text)) + return new Agent({ model, printer: false, id }) +} + +function makeGraph(id: string, agentIds: string[]): Graph { + return new Graph({ + id, + nodes: agentIds.map((aid) => makeAgent(aid)), + edges: agentIds.length > 1 ? [[agentIds[0]!, agentIds[1]!]] : [], + }) +} + +function makeSwarm(id: string, agentIds: string[]): Swarm { + return new Swarm({ + id, + nodes: agentIds.map((aid) => makeAgent(aid)), + }) +} + +function makeState(nodeIds: string[]): MultiAgentState { + return new MultiAgentState({ nodeIds }) +} + +describe('multiagent snapshot', () => { + beforeEach(() => { + vi.useFakeTimers() + vi.setSystemTime(new Date(MOCK_TIMESTAMP)) + }) + + afterEach(() => { + vi.useRealTimers() + }) + + describe('takeSnapshot', () => { + it('captures orchestratorId, serialized state, and appData', () => { + const graph = makeGraph('my-graph', ['a']) + const state = makeState(['a']) + state.steps = 3 + state.app.set('key', 'val') + + const snapshot = takeSnapshot(graph, state, { appData: { userId: 'u-1' } }) + + expect(snapshot).toEqual({ + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { + orchestratorId: 'my-graph', + state: expect.objectContaining({ steps: 3, app: { key: 'val' } }), + }, + appData: { userId: 'u-1' }, + }) + }) + + it('omits state when state parameter is undefined', () => { + const graph = makeGraph('g', ['a']) + + const snapshot = takeSnapshot(graph, undefined) + + expect(snapshot).toEqual({ + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { orchestratorId: 'g' }, + appData: {}, + }) + }) + + it('works with Swarm orchestrator', () => { + const swarm = makeSwarm('my-swarm', ['a', 'b']) + + const snapshot = takeSnapshot(swarm, makeState(['a', 'b'])) + + expect(snapshot).toEqual({ + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { + orchestratorId: 'my-swarm', + state: expect.any(Object), + }, + appData: {}, + }) + }) + }) + + describe('loadSnapshot', () => { + it('restores MultiAgentState for both Graph and Swarm', () => { + for (const [orchestrator, nodeIds] of [ + [makeGraph('g', ['a', 'b']), ['a', 'b']], + [makeSwarm('s', ['a', 'b']), ['a', 'b']], + ] as const) { + const state = makeState(nodeIds as unknown as string[]) + state.steps = 5 + state.results.push( + new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 100, content: [new TextBlock('done')] }) + ) + + const snapshot = takeSnapshot(orchestrator, state) + const restored = makeState([]) + loadSnapshot(orchestrator, snapshot, restored) + + expect(restored.steps).toBe(5) + expect(restored.results).toHaveLength(1) + expect(restored.results[0]!.nodeId).toBe('a') + } + }) + + it('does not modify state when snapshot has no state data', () => { + const graph = makeGraph('g', ['a']) + + const snapshotNoState = takeSnapshot(graph, undefined) + const state = makeState(['a']) + state.steps = 99 + loadSnapshot(graph, snapshotNoState, state) + expect(state.steps).toBe(99) + }) + + it('throws on wrong scope', () => { + const graph = makeGraph('g', ['a']) + const snapshot: Snapshot = { + scope: 'agent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { orchestratorId: 'g' }, + appData: {}, + } + + expect(() => loadSnapshot(graph, snapshot)).toThrow("Expected snapshot scope 'multiAgent', got 'agent'") + }) + + it('throws on unsupported schema version', () => { + const graph = makeGraph('g', ['a']) + const snapshot: Snapshot = { + scope: 'multiAgent', + schemaVersion: '99.0', + createdAt: MOCK_TIMESTAMP, + data: { orchestratorId: 'g' }, + appData: {}, + } + + expect(() => loadSnapshot(graph, snapshot)).toThrow('Unsupported snapshot schema version: 99.0') + }) + + it('throws on orchestratorId mismatch', () => { + const graph = makeGraph('g', ['a']) + const snapshot: Snapshot = { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: MOCK_TIMESTAMP, + data: { orchestratorId: 'different-id' }, + appData: {}, + } + + expect(() => loadSnapshot(graph, snapshot)).toThrow( + "Snapshot orchestrator ID mismatch: expected 'g', got 'different-id'" + ) + }) + }) + + describe('round-trip', () => { + it('snapshot survives JSON.stringify/JSON.parse round-trip', () => { + const graph = makeGraph('g', ['a', 'b']) + const state = makeState(['a', 'b']) + state.steps = 7 + state.app.set('counter', 42) + state.results.push( + new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 200, content: [new TextBlock('result')] }) + ) + + const snapshot = takeSnapshot(graph, state, { appData: { key: 'value' } }) + const parsed = JSON.parse(JSON.stringify(snapshot)) as Snapshot + + const restored = makeState([]) + loadSnapshot(graph, parsed, restored) + + expect(restored.steps).toBe(7) + expect(restored.app.get('counter')).toBe(42) + expect(restored.results).toHaveLength(1) + expect(restored.results[0]!.nodeId).toBe('a') + expect((restored.results[0]!.content[0] as TextBlock).text).toBe('result') + }) + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/state.test.ts b/strands-ts/src/multiagent/__tests__/state.test.ts new file mode 100644 index 0000000000..b01d26b351 --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/state.test.ts @@ -0,0 +1,650 @@ +import { describe, expect, it } from 'vitest' +import { NodeResult, NodeState, MultiAgentResult, MultiAgentState, Status } from '../state.js' +import { TextBlock, ToolUseBlock } from '../../types/messages.js' +import type { JSONValue } from '../../types/json.js' +import { + stateToJSONSymbol, + loadStateFromJSONSymbol, + serializeStateSerializable, + loadStateSerializable, +} from '../../types/serializable.js' +import { Interrupt } from '../../interrupt.js' +import { InterruptResponseContent } from '../../types/interrupt.js' +import { extractResumeResponses, groupInterruptResponsesByNode } from '../multiagent.js' + +describe('NodeResult', () => { + describe('toJSON / fromJSON', () => { + it('round-trips a completed result with text content', () => { + const original = new NodeResult({ + nodeId: 'agent-1', + status: Status.COMPLETED, + duration: 150, + content: [new TextBlock('hello world')], + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored).toMatchObject({ + nodeId: 'agent-1', + status: Status.COMPLETED, + duration: 150, + }) + expect(restored.content).toHaveLength(1) + expect(restored.content[0]).toBeInstanceOf(TextBlock) + expect((restored.content[0] as TextBlock).text).toBe('hello world') + expect(restored.error).toBeUndefined() + expect(restored.structuredOutput).toBeUndefined() + }) + + it('round-trips a failed result with error', () => { + const original = new NodeResult({ + nodeId: 'agent-2', + status: Status.FAILED, + duration: 50, + error: new Error('something broke'), + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored).toMatchObject({ + status: Status.FAILED, + content: [], + }) + expect(restored.error).toBeInstanceOf(Error) + expect(restored.error!.message).toBe('something broke') + }) + + it('round-trips structuredOutput with nested objects', () => { + const output = { name: 'Alice', scores: [1, 2, 3], nested: { deep: true } } + const original = new NodeResult({ + nodeId: 'agent-3', + status: Status.COMPLETED, + duration: 100, + structuredOutput: output, + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.structuredOutput).toEqual(output) + }) + + it('preserves structuredOutput when value is null', () => { + const original = new NodeResult({ + nodeId: 'agent-4', + status: Status.COMPLETED, + duration: 10, + structuredOutput: null, + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.structuredOutput).toBeNull() + }) + + it('preserves structuredOutput when value is a primitive', () => { + const original = new NodeResult({ + nodeId: 'agent-5', + status: Status.COMPLETED, + duration: 10, + structuredOutput: 42, + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.structuredOutput).toBe(42) + }) + + it('round-trips multiple content blocks including tool use', () => { + const original = new NodeResult({ + nodeId: 'agent-6', + status: Status.COMPLETED, + duration: 200, + content: [ + new TextBlock('thinking...'), + new ToolUseBlock({ toolUseId: 'tu-1', name: 'calculator', input: { expr: '2+2' } }), + ], + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.content).toHaveLength(2) + expect(restored.content[0]).toBeInstanceOf(TextBlock) + expect(restored.content[1]).toBeInstanceOf(ToolUseBlock) + expect((restored.content[1] as ToolUseBlock).name).toBe('calculator') + }) + + it('round-trips a cancelled result with empty content', () => { + const original = new NodeResult({ + nodeId: 'agent-7', + status: Status.CANCELLED, + duration: 0, + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored).toMatchObject({ + status: Status.CANCELLED, + content: [], + duration: 0, + }) + }) + + it('omits error from JSON when not present', () => { + const original = new NodeResult({ + nodeId: 'n', + status: Status.COMPLETED, + duration: 1, + }) + + const json = original.toJSON() as Record + + expect('error' in json).toBe(false) + }) + + it('omits structuredOutput from JSON when not present', () => { + const original = new NodeResult({ + nodeId: 'n', + status: Status.COMPLETED, + duration: 1, + }) + + const json = original.toJSON() as Record + + expect('structuredOutput' in json).toBe(false) + }) + + it('round-trips usage with all fields', () => { + const original = new NodeResult({ + nodeId: 'a', + status: Status.COMPLETED, + duration: 100, + usage: { + inputTokens: 10, + outputTokens: 20, + totalTokens: 30, + cacheReadInputTokens: 5, + cacheWriteInputTokens: 3, + }, + }) + + const restored = NodeResult.fromJSON(original.toJSON()) + + expect(restored.usage).toEqual(original.usage) + }) + + it('omits usage from JSON when not present', () => { + const original = new NodeResult({ + nodeId: 'n', + status: Status.COMPLETED, + duration: 1, + }) + + const json = original.toJSON() as Record + + expect('usage' in json).toBe(false) + }) + }) +}) + +describe('NodeState', () => { + describe('stateToJSONSymbol / loadStateFromJSONSymbol', () => { + it('round-trips a fresh node state', () => { + const original = new NodeState() + + const restored = new NodeState() + restored[loadStateFromJSONSymbol](original[stateToJSONSymbol]()) + + expect(restored).toMatchObject({ + status: Status.PENDING, + terminus: false, + startTime: original.startTime, + results: [], + }) + }) + + it('round-trips a node state with results', () => { + const original = new NodeState() + original.status = Status.COMPLETED + original.terminus = true + original.results.push( + new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 100, content: [new TextBlock('done')] }) + ) + original.results.push( + new NodeResult({ nodeId: 'a', status: Status.FAILED, duration: 50, error: new Error('retry failed') }) + ) + + const restored = new NodeState() + restored[loadStateFromJSONSymbol](original[stateToJSONSymbol]()) + + expect(restored).toMatchObject({ + status: Status.COMPLETED, + terminus: true, + }) + expect(restored.results).toHaveLength(2) + expect(restored.results[0]).toMatchObject({ status: Status.COMPLETED }) + expect(restored.results[1]).toMatchObject({ status: Status.FAILED }) + expect(restored.results[1]!.error!.message).toBe('retry failed') + }) + + it('preserves content accessor after round-trip', () => { + const original = new NodeState() + original.results.push( + new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 10, content: [new TextBlock('last')] }) + ) + + const restored = new NodeState() + restored[loadStateFromJSONSymbol](original[stateToJSONSymbol]()) + + expect(restored.content).toHaveLength(1) + expect((restored.content[0] as TextBlock).text).toBe('last') + }) + + it('loads state into existing instance via loadStateFromJSONSymbol', () => { + const original = new NodeState() + original.status = Status.COMPLETED + original.terminus = true + original.results.push(new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 100 })) + + const target = new NodeState() + target[loadStateFromJSONSymbol](original[stateToJSONSymbol]()) + + expect(target).toMatchObject({ + status: Status.COMPLETED, + terminus: true, + startTime: original.startTime, + }) + expect(target.results).toHaveLength(1) + }) + + it('round-trips interrupts and interruptedSnapshot on an INTERRUPTED node', () => { + const original = new NodeState() + original.status = Status.INTERRUPTED + original.interrupts = [new Interrupt({ id: 'tool:1:confirm', name: 'confirm', reason: 'need it' })] + original.interruptedSnapshot = { + scope: 'agent', + schemaVersion: '1.0', + createdAt: '2026-01-01T00:00:00Z', + data: { messages: [], interrupts: { activated: true, interrupts: {} } }, + appData: {}, + } + + const restored = new NodeState() + loadStateSerializable(restored, serializeStateSerializable(original)) + + expect(restored).toEqual(original) + }) + + it('clears interruptedSnapshot when it is absent from the serialized state', () => { + const original = new NodeState() + original.status = Status.COMPLETED + + const restored = new NodeState() + restored.interruptedSnapshot = { + scope: 'agent', + schemaVersion: '1.0', + createdAt: '2026-01-01T00:00:00Z', + data: {}, + appData: {}, + } + loadStateSerializable(restored, serializeStateSerializable(original)) + + expect(restored).toEqual(original) + }) + }) +}) + +describe('MultiAgentResult', () => { + describe('toJSON / fromJSON', () => { + it('round-trips a completed result', () => { + const nodeResult = new NodeResult({ + nodeId: 'writer', + status: Status.COMPLETED, + duration: 300, + content: [new TextBlock('final answer')], + }) + const original = new MultiAgentResult({ + results: [nodeResult], + content: [new TextBlock('final answer')], + duration: 500, + }) + + const restored = MultiAgentResult.fromJSON(original.toJSON()) + + expect(restored).toMatchObject({ + status: Status.COMPLETED, + duration: 500, + }) + expect(restored.results).toHaveLength(1) + expect(restored.results[0]).toMatchObject({ nodeId: 'writer' }) + expect(restored.content).toHaveLength(1) + expect((restored.content[0] as TextBlock).text).toBe('final answer') + expect(restored.error).toBeUndefined() + }) + + it('round-trips a failed result with error', () => { + const original = new MultiAgentResult({ + status: Status.FAILED, + results: [], + duration: 10, + error: new Error('orchestration failed'), + }) + + const restored = MultiAgentResult.fromJSON(original.toJSON()) + + expect(restored).toMatchObject({ status: Status.FAILED }) + expect(restored.error).toBeInstanceOf(Error) + expect(restored.error!.message).toBe('orchestration failed') + }) + + it('preserves explicit status override', () => { + const nodeResult = new NodeResult({ + nodeId: 'a', + status: Status.COMPLETED, + duration: 10, + }) + const original = new MultiAgentResult({ + status: Status.CANCELLED, + results: [nodeResult], + duration: 20, + }) + + const restored = MultiAgentResult.fromJSON(original.toJSON()) + + expect(restored.status).toBe(Status.CANCELLED) + }) + + it('round-trips with empty results and content', () => { + const original = new MultiAgentResult({ + results: [], + duration: 0, + }) + + const restored = MultiAgentResult.fromJSON(original.toJSON()) + + expect(restored).toMatchObject({ + status: Status.COMPLETED, + results: [], + content: [], + }) + }) + + it('preserves aggregated usage after round-trip', () => { + const original = new MultiAgentResult({ + results: [ + new NodeResult({ + nodeId: 'a', + status: Status.COMPLETED, + duration: 10, + usage: { inputTokens: 5, outputTokens: 10, totalTokens: 15 }, + }), + new NodeResult({ + nodeId: 'b', + status: Status.COMPLETED, + duration: 20, + usage: { inputTokens: 3, outputTokens: 7, totalTokens: 10 }, + }), + ], + duration: 30, + }) + + expect(original.usage).toMatchObject({ inputTokens: 8, outputTokens: 17 }) + + const restored = MultiAgentResult.fromJSON(original.toJSON()) + + expect(restored.usage).toMatchObject({ inputTokens: 8, outputTokens: 17, totalTokens: 25 }) + }) + }) +}) + +describe('MultiAgentState', () => { + describe('stateToJSONSymbol / loadStateFromJSONSymbol', () => { + it('round-trips a fresh state with node IDs', () => { + const original = new MultiAgentState({ nodeIds: ['a', 'b', 'c'] }) + + const restored = new MultiAgentState() + restored[loadStateFromJSONSymbol](original[stateToJSONSymbol]()) + + expect(restored).toMatchObject({ + startTime: original.startTime, + steps: 0, + results: [], + }) + expect(restored.nodes.size).toBe(3) + expect(restored.node('a')).toBeDefined() + expect(restored.node('b')).toBeDefined() + expect(restored.node('c')).toBeDefined() + }) + + it('round-trips state with steps and results', () => { + const original = new MultiAgentState({ nodeIds: ['researcher', 'writer'] }) + original.steps = 3 + original.results.push( + new NodeResult({ + nodeId: 'researcher', + status: Status.COMPLETED, + duration: 200, + content: [new TextBlock('research findings')], + }) + ) + original.results.push( + new NodeResult({ + nodeId: 'writer', + status: Status.COMPLETED, + duration: 150, + content: [new TextBlock('polished output')], + }) + ) + + const restored = new MultiAgentState() + restored[loadStateFromJSONSymbol](original[stateToJSONSymbol]()) + + expect(restored.steps).toBe(3) + expect(restored.results).toHaveLength(2) + expect(restored.results[0]).toMatchObject({ nodeId: 'researcher' }) + expect(restored.results[1]).toMatchObject({ nodeId: 'writer' }) + }) + + it('round-trips app state', () => { + const original = new MultiAgentState() + original.app.set('counter', 42) + original.app.set('config', { nested: { key: 'value' }, list: [1, 2, 3] }) + + const restored = new MultiAgentState() + restored[loadStateFromJSONSymbol](original[stateToJSONSymbol]()) + + expect(restored.app.get('counter')).toBe(42) + expect(restored.app.get('config')).toEqual({ nested: { key: 'value' }, list: [1, 2, 3] }) + }) + + it('round-trips node states with modified status and results', () => { + const original = new MultiAgentState({ nodeIds: ['agent-1'] }) + const ns = original.node('agent-1')! + ns.status = Status.COMPLETED + ns.terminus = true + ns.results.push(new NodeResult({ nodeId: 'agent-1', status: Status.COMPLETED, duration: 100 })) + + const restored = new MultiAgentState() + restored[loadStateFromJSONSymbol](original[stateToJSONSymbol]()) + + const restoredNs = restored.node('agent-1')! + expect(restoredNs).toMatchObject({ + status: Status.COMPLETED, + terminus: true, + }) + expect(restoredNs.results).toHaveLength(1) + }) + + it('round-trips an empty state (no node IDs)', () => { + const original = new MultiAgentState() + + const restored = new MultiAgentState() + restored[loadStateFromJSONSymbol](original[stateToJSONSymbol]()) + + expect(restored).toMatchObject({ + steps: 0, + results: [], + }) + expect(restored.nodes.size).toBe(0) + }) + + it('handles loadStateFromJSONSymbol with missing nodes key gracefully', () => { + const json = { + startTime: 1000, + steps: 0, + results: [], + app: {}, + } as JSONValue + + const restored = new MultiAgentState() + restored[loadStateFromJSONSymbol](json) + + expect(restored).toMatchObject({ startTime: 1000 }) + expect(restored.nodes.size).toBe(0) + }) + + it('preserves startTime exactly (no re-initialization)', () => { + const json = { + startTime: 1234567890, + steps: 5, + results: [], + app: {}, + nodes: {}, + } as JSONValue + + const restored = new MultiAgentState() + restored[loadStateFromJSONSymbol](json) + + expect(restored).toMatchObject({ + startTime: 1234567890, + steps: 5, + }) + }) + + it('round-trips _pendingInput as a string', () => { + const original = new MultiAgentState() + original._pendingInput = 'hello' + const restored = new MultiAgentState() + loadStateSerializable(restored, JSON.parse(JSON.stringify(serializeStateSerializable(original))) as JSONValue) + expect(restored._pendingInput).toBe('hello') + }) + + it('rehydrates _pendingInput ContentBlock[] to ContentBlock instances', () => { + // Round-trips through JSON.stringify/parse to simulate FileStorage persistence, + // then asserts the restored entries are real ContentBlock instances rather than + // raw data objects — agent message construction depends on instance shape for + // some downstream code paths. + const original = new MultiAgentState() + original._pendingInput = [new TextBlock('question')] + const serialized = JSON.parse(JSON.stringify(serializeStateSerializable(original))) as JSONValue + const restored = new MultiAgentState() + loadStateSerializable(restored, serialized) + + expect(restored._pendingInput).toEqual([new TextBlock('question')]) + expect((restored._pendingInput as TextBlock[])[0]).toBeInstanceOf(TextBlock) + }) + }) +}) + +describe('MultiAgentResult._resolveStatus precedence', () => { + function makeResult( + status: typeof Status.COMPLETED | typeof Status.FAILED | typeof Status.CANCELLED | typeof Status.INTERRUPTED, + nodeId = 'n' + ): NodeResult { + return new NodeResult({ nodeId, status, duration: 1 }) + } + + it('returns COMPLETED when all node results are completed', () => { + const r = new MultiAgentResult({ + results: [makeResult(Status.COMPLETED), makeResult(Status.COMPLETED)], + duration: 10, + }) + expect(r.status).toBe(Status.COMPLETED) + }) + + it('FAILED outranks INTERRUPTED', () => { + const r = new MultiAgentResult({ + results: [makeResult(Status.INTERRUPTED), makeResult(Status.FAILED)], + duration: 10, + }) + expect(r.status).toBe(Status.FAILED) + }) + + it('INTERRUPTED outranks CANCELLED', () => { + const r = new MultiAgentResult({ + results: [makeResult(Status.CANCELLED), makeResult(Status.INTERRUPTED)], + duration: 10, + }) + expect(r.status).toBe(Status.INTERRUPTED) + }) + + it('CANCELLED outranks COMPLETED', () => { + const r = new MultiAgentResult({ + results: [makeResult(Status.COMPLETED), makeResult(Status.CANCELLED)], + duration: 10, + }) + expect(r.status).toBe(Status.CANCELLED) + }) + + it('FAILED outranks CANCELLED', () => { + const r = new MultiAgentResult({ results: [makeResult(Status.CANCELLED), makeResult(Status.FAILED)], duration: 10 }) + expect(r.status).toBe(Status.FAILED) + }) +}) + +describe('groupInterruptResponsesByNode', () => { + function makeState(nodeInterrupts: Record): MultiAgentState { + const state = new MultiAgentState({ nodeIds: Object.keys(nodeInterrupts) }) + for (const [id, interrupts] of Object.entries(nodeInterrupts)) { + state.node(id)!.interrupts = interrupts + } + return state + } + + it('groups responses by the node whose interrupts match each id', () => { + const state = makeState({ + a: [new Interrupt({ id: 'tool:1:confirm', name: 'confirm' })], + b: [new Interrupt({ id: 'tool:2:approve', name: 'approve' })], + }) + const responses = [ + new InterruptResponseContent({ interruptId: 'tool:1:confirm', response: 'yes' }), + new InterruptResponseContent({ interruptId: 'tool:2:approve', response: 'ok' }), + ] + + const grouped = groupInterruptResponsesByNode(responses, state) + + expect(grouped.get('a')).toHaveLength(1) + expect(grouped.get('b')).toHaveLength(1) + expect(grouped.get('a')?.[0]?.interruptResponse.interruptId).toBe('tool:1:confirm') + }) + + it('throws when a response id does not match any node interrupt', () => { + const state = makeState({ a: [new Interrupt({ id: 'tool:1:confirm', name: 'confirm' })] }) + const responses = [new InterruptResponseContent({ interruptId: 'tool:missing:xyz', response: 'yes' })] + + expect(() => groupInterruptResponsesByNode(responses, state)).toThrow(/tool:missing:xyz/) + }) + + it('returns an empty map for empty responses', () => { + const state = makeState({ a: [new Interrupt({ id: 'tool:1:confirm', name: 'confirm' })] }) + const grouped = groupInterruptResponsesByNode([], state) + expect(grouped.size).toBe(0) + }) +}) + +describe('extractResumeResponses', () => { + it('throws when interrupt responses are mixed with other content', () => { + // Cast through `unknown` since the public type rejects mixed arrays at compile-time; + // this test pins the runtime guard for callers that bypass typing. + const mixed = [ + new InterruptResponseContent({ interruptId: 'tool:1:confirm', response: 'ok' }), + new TextBlock('stray content'), + ] as unknown as InterruptResponseContent[] + expect(() => extractResumeResponses(mixed)).toThrow(TypeError) + }) + + it('returns undefined for empty input or non-response arrays', () => { + expect(extractResumeResponses([])).toBeUndefined() + expect(extractResumeResponses('hello')).toBeUndefined() + expect(extractResumeResponses([new TextBlock('hi')])).toBeUndefined() + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/swarm.invocation-state.test.ts b/strands-ts/src/multiagent/__tests__/swarm.invocation-state.test.ts new file mode 100644 index 0000000000..04ef33ff6c --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/swarm.invocation-state.test.ts @@ -0,0 +1,70 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { BeforeModelCallEvent } from '../../hooks/events.js' +import type { JSONValue } from '../../types/json.js' +import { Swarm } from '../swarm.js' +import type { InvocationState } from '../../types/agent.js' + +/** + * Agent that hands off to `nextAgentId` via the structured-output tool, or + * terminates when `nextAgentId` is undefined. + */ +function makeHandoffAgent(id: string, nextAgentId: string | undefined, message: string): Agent { + const handoff: { agentId?: string; message: string } = { message } + if (nextAgentId !== undefined) handoff.agentId = nextAgentId + + const model = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: `tool-${id}`, + input: handoff as JSONValue, + }) + return new Agent({ model, printer: false, id, description: `Agent ${id}` }) +} + +describe('Swarm invocationState forwarding', () => { + it('forwards invocationState to every node and mutations from one node are visible to the next', async () => { + const nodeAObserved: InvocationState[] = [] + const nodeBObserved: InvocationState[] = [] + + const agentA = makeHandoffAgent('a', 'b', 'to b') + agentA.addHook(BeforeModelCallEvent, (event) => { + nodeAObserved.push(event.invocationState) + event.invocationState.touchedByA = true + }) + + const agentB = makeHandoffAgent('b', undefined, 'done') + agentB.addHook(BeforeModelCallEvent, (event) => { + nodeBObserved.push(event.invocationState) + }) + + const swarm = new Swarm({ nodes: [agentA, agentB], start: 'a' }) + + const state: InvocationState = { requestId: 'r-1' } + await swarm.invoke('hello', { invocationState: state }) + + // Both nodes observe the same object reference. + expect(nodeAObserved[0]).toBe(state) + expect(nodeBObserved[0]).toBe(state) + + // Node B sees node A's mutation. + expect(nodeBObserved[0]?.touchedByA).toBe(true) + expect(state.touchedByA).toBe(true) + }) + + it('defaults invocationState to {} when none is passed', async () => { + let observed: InvocationState | undefined + + const agentA = makeHandoffAgent('a', undefined, 'done') + agentA.addHook(BeforeModelCallEvent, (event) => { + observed = event.invocationState + }) + + const swarm = new Swarm({ nodes: [agentA], start: 'a' }) + + await swarm.invoke('hello') + + expect(observed).toEqual({}) + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/swarm.test.ts b/strands-ts/src/multiagent/__tests__/swarm.test.ts new file mode 100644 index 0000000000..55f0244cc5 --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/swarm.test.ts @@ -0,0 +1,616 @@ +import { describe, expect, it, vi } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import { createCancellableAgent } from '../../__fixtures__/agent-helpers.js' +import { BeforeNodeCallEvent, MultiAgentInitializedEvent } from '../events.js' +import type { JSONValue } from '../../types/json.js' +import { TextBlock } from '../../types/messages.js' +import { Status, MultiAgentState } from '../state.js' +import { AgentNode } from '../nodes.js' +import { Swarm } from '../swarm.js' +import { SessionManager } from '../../session/session-manager.js' +import { MockSnapshotStorage } from '../../__fixtures__/mock-storage-provider.js' + +/** + * Creates an agent that produces a structured output handoff via the strands_structured_output tool. + * The agent exits after the structured output tool succeeds (early-exit behavior). + */ +function createHandoffAgent( + agentId: string, + handoff: { agentId?: string; message: string; context?: Record }, + description: string = `Agent ${agentId}` +): Agent { + const model = new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: handoff as JSONValue, + }) + return new Agent({ model, printer: false, id: agentId, description }) +} + +/** + * Creates a simple agent that produces a final response (no handoff). + */ +function createFinalAgent(agentId: string, message: string, description: string = `Agent ${agentId}`): Agent { + return createHandoffAgent(agentId, { message }, description) +} + +describe('Swarm', () => { + describe('constructor', () => { + it('defaults id to "swarm"', () => { + const swarm = new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + }) + expect(swarm.id).toBe('swarm') + }) + + it('accepts a custom id', () => { + const swarm = new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + id: 'my-swarm', + }) + expect(swarm.id).toBe('my-swarm') + }) + + it('accepts AgentNodeOptions with per-node config', () => { + const swarm = new Swarm({ + nodes: [{ agent: createFinalAgent('a', 'hi') }], + start: 'a', + }) + expect(swarm.nodes.get('a')).toBeInstanceOf(AgentNode) + }) + + it('defaults start to the first node when not specified', () => { + const swarm = new Swarm({ + nodes: [createFinalAgent('first', 'hi'), createFinalAgent('second', 'bye')], + }) + + expect(swarm.start.id).toBe('first') + }) + + it('throws when start references unknown agent', () => { + expect( + () => + new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'missing', + }) + ).toThrow('start= | start references unknown agent') + }) + + it('throws when nodes list is empty', () => { + expect(() => new Swarm({ nodes: [] })).toThrow('nodes list is empty') + }) + + it('throws on duplicate agent ids', () => { + const agent = createFinalAgent('a', 'hi') + expect( + () => + new Swarm({ + nodes: [agent, agent], + start: 'a', + }) + ).toThrow('agent_id= | duplicate agent id') + }) + + it('throws when maxSteps < 1', () => { + expect( + () => + new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + maxSteps: 0, + }) + ).toThrow('max_steps=<0> | must be at least 1') + }) + + it('defaults maxSteps, timeout, and nodeTimeout to Infinity', () => { + const swarm = new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + }) + expect(swarm.config.maxSteps).toBe(Infinity) + expect(swarm.config.timeout).toBe(Infinity) + expect(swarm.config.nodeTimeout).toBe(Infinity) + }) + + it('throws when timeout < 1', () => { + expect( + () => + new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + timeout: 0, + }) + ).toThrow('timeout=<0> | must be at least 1') + }) + + it('throws when nodeTimeout < 1', () => { + expect( + () => + new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + nodeTimeout: 0, + }) + ).toThrow('node_timeout=<0> | must be at least 1') + }) + }) + + describe('invoke', () => { + it('returns completed result with content and duration', async () => { + const swarm = new Swarm({ + nodes: [createFinalAgent('a', 'final answer')], + start: 'a', + }) + + const result = await swarm.invoke('hello') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + duration: expect.any(Number), + content: [expect.objectContaining({ type: 'textBlock', text: 'final answer' })], + }) + ) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a']) + expect(result.results[0]?.structuredOutput).toEqual({ message: 'final answer' }) + }) + + it('hands off from A to B and returns final output', async () => { + const swarm = new Swarm({ + nodes: [ + createHandoffAgent('a', { agentId: 'b', message: 'please handle this' }), + createFinalAgent('b', 'done by b'), + ], + start: 'a', + }) + + const result = await swarm.invoke('start') + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'b']) + }) + + it('chains handoffs across multiple agents (A → B → C)', async () => { + const swarm = new Swarm({ + nodes: [ + createHandoffAgent('a', { agentId: 'b', message: 'go to b' }), + createHandoffAgent('b', { agentId: 'c', message: 'go to c' }), + createFinalAgent('c', 'final from c'), + ], + start: 'a', + }) + + const result = await swarm.invoke('start') + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'b', 'c']) + }) + + it('passes serialized context in handoff input', async () => { + const contextData = { key: 'value', num: 42 } + const agentB = createFinalAgent('b', 'done') + const streamSpy = vi.spyOn(agentB, 'stream') + + const swarm = new Swarm({ + nodes: [createHandoffAgent('a', { agentId: 'b', message: 'handle this', context: contextData }), agentB], + start: 'a', + }) + + await swarm.invoke('start') + + expect(streamSpy).toHaveBeenCalled() + const args = streamSpy.mock.calls[0]![0] as TextBlock[] + const texts = args.map((b) => b.text) + expect(texts).toContainEqual('handle this') + expect(texts).toContainEqual(expect.stringContaining(JSON.stringify(contextData, null, 2))) + }) + + it('excludes current agent from handoff schema', async () => { + const agentA = createHandoffAgent('a', { agentId: 'b', message: 'go to b' }) + const agentB = createFinalAgent('b', 'done') + const streamSpyA = vi.spyOn(agentA, 'stream') + const streamSpyB = vi.spyOn(agentB, 'stream') + + const swarm = new Swarm({ + nodes: [agentA, agentB], + start: 'a', + }) + + await swarm.invoke('start') + + // Agent A's handoff schema allows B but rejects A + const schemaA = streamSpyA.mock.calls[0]![1]!.structuredOutputSchema! + expect(schemaA.parse({ agentId: 'b', message: 'ok' })).toStrictEqual({ agentId: 'b', message: 'ok' }) + expect(() => schemaA.parse({ agentId: 'a', message: 'ok' })).toThrow() + + // Agent B's handoff schema allows A but rejects B + const schemaB = streamSpyB.mock.calls[0]![1]!.structuredOutputSchema! + expect(schemaB.parse({ agentId: 'a', message: 'ok' })).toStrictEqual({ agentId: 'a', message: 'ok' }) + expect(() => schemaB.parse({ agentId: 'b', message: 'ok' })).toThrow() + }) + + it('throws when maxSteps is exceeded', async () => { + const swarm = new Swarm({ + nodes: [createHandoffAgent('a', { agentId: 'b', message: 'to b' }), createFinalAgent('b', 'done')], + start: 'a', + maxSteps: 1, + }) + + await expect(swarm.invoke('start')).rejects.toThrow('swarm reached step limit') + }) + + it('does not throw when swarm completes normally using exactly maxSteps', async () => { + const swarm = new Swarm({ + nodes: [createHandoffAgent('a', { agentId: 'b', message: 'to b' }), createFinalAgent('b', 'done by b')], + start: 'a', + maxSteps: 2, + }) + + const result = await swarm.invoke('start') + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'b']) + }) + + it('throws when a node exceeds nodeTimeout', async () => { + const swarm = new Swarm({ + nodes: [{ agent: createCancellableAgent('slow', 100) }], + start: 'slow', + nodeTimeout: 20, + }) + + await expect(swarm.invoke('go')).rejects.toThrow(/node_timeout=<20>, node_id=/) + }) + + it('applies per-node timeout over nodeTimeout', async () => { + const swarm = new Swarm({ + nodes: [{ agent: createCancellableAgent('slow', 100), timeout: 15 }], + start: 'slow', + nodeTimeout: 10_000, + }) + + await expect(swarm.invoke('go')).rejects.toThrow(/node_timeout=<15>, node_id=/) + }) + + it('does not throw when nodeTimeout is Infinity', async () => { + const swarm = new Swarm({ + nodes: [{ agent: createCancellableAgent('a', 20) }], + start: 'a', + nodeTimeout: Infinity, + }) + + const result = await swarm.invoke('go') + expect(result.status).toBe(Status.COMPLETED) + }) + + it('per-node timeout of Infinity disables a finite nodeTimeout', async () => { + const swarm = new Swarm({ + nodes: [{ agent: createCancellableAgent('slow', 30), timeout: Infinity }], + start: 'slow', + nodeTimeout: 10, + }) + + const result = await swarm.invoke('go') + expect(result.status).toBe(Status.COMPLETED) + }) + + it('throws when timeout is exceeded between steps', async () => { + const swarm = new Swarm({ + nodes: [ + { agent: createCancellableAgent('a', 30, { agentId: 'b', message: 'to b' }) }, + { agent: createCancellableAgent('b', 30) }, + ], + start: 'a', + timeout: 20, + }) + + await expect(swarm.invoke('go')).rejects.toThrow(/timeout=<20>/) + }) + + it('aborts an in-flight node when the swarm timeout expires mid-step', async () => { + const swarm = new Swarm({ + nodes: [{ agent: createCancellableAgent('slow', 200) }], + start: 'slow', + timeout: 20, + }) + + await expect(swarm.invoke('go')).rejects.toThrow(/timeout=<20>/) + }) + + it('returns cancelled result with custom message when cancel is a string', async () => { + const swarm = new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + }) + + swarm.addHook(BeforeNodeCallEvent, (event: BeforeNodeCallEvent) => { + event.cancel = 'agent not ready' + }) + + const { items, result } = await collectGenerator(swarm.stream('go')) + + expect(result.status).toBe(Status.CANCELLED) + + const cancelEvent = items.find((e) => e.type === 'nodeCancelEvent') + expect(cancelEvent).toEqual( + expect.objectContaining({ nodeId: 'a', state: expect.any(MultiAgentState), message: 'agent not ready' }) + ) + }) + + it('returns failed result when agent throws', async () => { + const model = new MockMessageModel().addTurn(new Error('agent exploded')) + const agent = new Agent({ model, printer: false, id: 'a', description: 'Agent a' }) + + const swarm = new Swarm({ + nodes: [{ agent }], + start: 'a', + }) + + const result = await swarm.invoke('go') + + expect(result.status).toBe(Status.FAILED) + expect(result.results).toHaveLength(1) + expect(result.results[0]).toEqual(expect.objectContaining({ nodeId: 'a', status: Status.FAILED })) + }) + + it('calls initialize only once across invocations', async () => { + let callCount = 0 + + const swarm = new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + }) + + swarm.addHook(MultiAgentInitializedEvent, () => { + callCount++ + }) + + await swarm.invoke('first') + await swarm.invoke('second') + + expect(callCount).toBe(1) + }) + + it('preserves agent messages and state after execution', async () => { + const agent = createFinalAgent('a', 'reply') + const messagesBefore = [...agent.messages] + const stateBefore = agent.appState.getAll() + + const swarm = new Swarm({ + nodes: [agent], + start: 'a', + }) + + await swarm.invoke('hello') + + expect(agent.messages).toStrictEqual(messagesBefore) + expect(agent.appState.getAll()).toStrictEqual(stateBefore) + }) + }) + + describe('stream', () => { + it('yields lifecycle events in correct order for single agent', async () => { + const swarm = new Swarm({ + nodes: [createFinalAgent('a', 'reply')], + start: 'a', + }) + + const { items, result } = await collectGenerator(swarm.stream('go')) + const eventTypes = items.map((e) => e.type) + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a']) + expect(eventTypes).toStrictEqual([ + 'beforeMultiAgentInvocationEvent', + 'beforeNodeCallEvent', + // nodeStreamUpdateEvents from agent execution + ...eventTypes.filter((t) => t === 'nodeStreamUpdateEvent'), + 'nodeResultEvent', + 'afterNodeCallEvent', + 'afterMultiAgentInvocationEvent', + 'multiAgentResultEvent', + ]) + }) + + it('yields handoff event between agents', async () => { + const swarm = new Swarm({ + nodes: [createHandoffAgent('a', { agentId: 'b', message: 'go' }), createFinalAgent('b', 'done')], + start: 'a', + }) + + const { items } = await collectGenerator(swarm.stream('start')) + const handoffEvents = items.filter((e) => e.type === 'multiAgentHandoffEvent') + + expect(handoffEvents).toHaveLength(1) + expect(handoffEvents[0]).toEqual( + expect.objectContaining({ + type: 'multiAgentHandoffEvent', + source: 'a', + targets: ['b'], + state: expect.any(MultiAgentState), + }) + ) + }) + + it('returns cancelled result with default message when cancel is true', async () => { + const swarm = new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + }) + + swarm.addHook(BeforeNodeCallEvent, (event: BeforeNodeCallEvent) => { + event.cancel = true + }) + + const { items, result } = await collectGenerator(swarm.stream('go')) + + expect(result.status).toBe(Status.CANCELLED) + expect(result.results).toHaveLength(1) + expect(result.results[0]).toEqual(expect.objectContaining({ nodeId: 'a', status: Status.CANCELLED, duration: 0 })) + + const cancelEvent = items.find((e) => e.type === 'nodeCancelEvent') + expect(cancelEvent).toEqual( + expect.objectContaining({ nodeId: 'a', state: expect.any(MultiAgentState), message: 'node cancelled by hook' }) + ) + }) + + it('returns cancelled result with custom message when cancel is a string', async () => { + const swarm = new Swarm({ + nodes: [createFinalAgent('a', 'hi')], + start: 'a', + }) + + swarm.addHook(BeforeNodeCallEvent, (event: BeforeNodeCallEvent) => { + event.cancel = 'agent not ready' + }) + + const { items, result } = await collectGenerator(swarm.stream('go')) + + expect(result.status).toBe(Status.CANCELLED) + + const cancelEvent = items.find((e) => e.type === 'nodeCancelEvent') + expect(cancelEvent).toEqual( + expect.objectContaining({ nodeId: 'a', state: expect.any(MultiAgentState), message: 'agent not ready' }) + ) + }) + }) + + describe('resume with session manager', () => { + function makeResumeSwarm(storage: MockSnapshotStorage, options: { maxSteps?: number } = {}): Swarm { + const sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + const swarm = new Swarm({ + id: 'my-swarm', + nodes: [createHandoffAgent('a', { agentId: 'b', message: 'go to b' }), createFinalAgent('b', 'done by b')], + start: 'a', + plugins: [sessionManager], + ...options, + }) + return swarm + } + + it('resumes from the pending handoff target after a crash (A→B stopped, resumes at B)', async () => { + const storage = new MockSnapshotStorage() + + const swarm1 = makeResumeSwarm(storage, { maxSteps: 1 }) + await expect(swarm1.invoke('start')).rejects.toThrow('swarm reached step limit') + + const swarm2 = makeResumeSwarm(storage) + const result = await swarm2.invoke('start') + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'b']) + }) + + it('starts fresh when the previous run completed normally (no pending handoff)', async () => { + const storage = new MockSnapshotStorage() + const sessionManager1 = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + + const swarm1 = new Swarm({ + id: 'my-swarm', + nodes: [createFinalAgent('a', 'all done'), createFinalAgent('b', 'done by b')], + start: 'a', + plugins: [sessionManager1], + }) + + const result1 = await swarm1.invoke('start') + expect(result1.status).toBe(Status.COMPLETED) + expect(result1.results.map((r) => r.nodeId)).toStrictEqual(['a']) + + const result2 = await swarm1.invoke('start') + + expect(result2.status).toBe(Status.COMPLETED) + expect(result2.results.map((r) => r.nodeId)).toStrictEqual(['a']) + }) + + it('carries forward steps count from the previous invocation', async () => { + const storage = new MockSnapshotStorage() + + const swarm1 = makeResumeSwarm(storage, { maxSteps: 1 }) + await expect(swarm1.invoke('start')).rejects.toThrow('swarm reached step limit') + + const swarm2 = makeResumeSwarm(storage, { maxSteps: 2 }) + const result = await swarm2.invoke('start') + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'b']) + }) + + it('passes the last handoff context to the resumed node', async () => { + const storage = new MockSnapshotStorage() + const handoffContext = { research: 'quantum computing basics' } + + const sessionManager1 = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + const swarm1 = new Swarm({ + id: 'my-swarm', + nodes: [ + createHandoffAgent('a', { agentId: 'b', message: 'write this up', context: handoffContext }), + createFinalAgent('b', 'done'), + ], + start: 'a', + maxSteps: 1, + plugins: [sessionManager1], + }) + + await expect(swarm1.invoke('start')).rejects.toThrow('swarm reached step limit') + + const sessionManager2 = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + const agentB = createFinalAgent('b', 'done') + const streamSpy = vi.spyOn(agentB, 'stream') + + const swarm2 = new Swarm({ + id: 'my-swarm', + nodes: [createHandoffAgent('a', { agentId: 'b', message: 'write this up', context: handoffContext }), agentB], + start: 'a', + plugins: [sessionManager2], + }) + + await swarm2.invoke('start') + + expect(streamSpy).toHaveBeenCalled() + const args = streamSpy.mock.calls[0]![0] as TextBlock[] + const texts = args.map((b) => b.text) + expect(texts).toContainEqual('write this up') + expect(texts).toContainEqual(expect.stringContaining(JSON.stringify(handoffContext, null, 2))) + }) + + it('starts fresh when the resume target agent was removed from the swarm', async () => { + const storage = new MockSnapshotStorage() + + // First invocation: A hands off to B, maxSteps=1 stops + const sessionManager1 = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + const swarm1 = new Swarm({ + id: 'my-swarm', + nodes: [createHandoffAgent('a', { agentId: 'b', message: 'go to b' }), createFinalAgent('b', 'done by b')], + start: 'a', + maxSteps: 1, + plugins: [sessionManager1], + }) + + await expect(swarm1.invoke('start')).rejects.toThrow('swarm reached step limit') + + // Second invocation: swarm reconfigured — B removed, C added + const sessionManager2 = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + const swarm2 = new Swarm({ + id: 'my-swarm', + nodes: [createFinalAgent('a', 'fresh start'), createFinalAgent('c', 'done by c')], + start: 'a', + plugins: [sessionManager2], + }) + + const result = await swarm2.invoke('start') + + // B no longer exists, so _findResumeNode falls back to start node A + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['a', 'a']) + }) + }) +}) diff --git a/strands-ts/src/multiagent/__tests__/swarm.tracer.test.ts b/strands-ts/src/multiagent/__tests__/swarm.tracer.test.ts new file mode 100644 index 0000000000..6452161984 --- /dev/null +++ b/strands-ts/src/multiagent/__tests__/swarm.tracer.test.ts @@ -0,0 +1,312 @@ +import { describe, expect, it, vi, beforeEach, type MockInstance } from 'vitest' +import { Agent } from '../../agent/agent.js' +import { MockMessageModel } from '../../__fixtures__/mock-message-model.js' +import { TextBlock } from '../../types/messages.js' +import type { JSONValue } from '../../types/json.js' +import { Tracer } from '../../telemetry/tracer.js' +import { Swarm } from '../swarm.js' +import { BeforeNodeCallEvent } from '../events.js' +import { Status } from '../state.js' + +interface MockTracerInstance { + startAgentSpan: MockInstance + endAgentSpan: MockInstance + startAgentLoopSpan: MockInstance + endAgentLoopSpan: MockInstance + startModelInvokeSpan: MockInstance + endModelInvokeSpan: MockInstance + startToolCallSpan: MockInstance + endToolCallSpan: MockInstance + startMultiAgentSpan: MockInstance + endMultiAgentSpan: MockInstance + startNodeSpan: MockInstance + endNodeSpan: MockInstance + withSpanContext: MockInstance +} + +vi.mock('../../telemetry/tracer.js', () => ({ + Tracer: vi.fn(function () { + return { + startAgentSpan: vi.fn().mockReturnValue({ mock: 'agentSpan' }), + endAgentSpan: vi.fn(), + startAgentLoopSpan: vi.fn().mockReturnValue({ mock: 'loopSpan' }), + endAgentLoopSpan: vi.fn(), + startModelInvokeSpan: vi.fn().mockReturnValue({ mock: 'modelSpan' }), + endModelInvokeSpan: vi.fn(), + startToolCallSpan: vi.fn().mockReturnValue({ mock: 'toolSpan' }), + endToolCallSpan: vi.fn(), + startMultiAgentSpan: vi.fn().mockReturnValue({ mock: 'multiAgentSpan' }), + endMultiAgentSpan: vi.fn(), + startNodeSpan: vi.fn().mockReturnValue({ mock: 'nodeSpan' }), + endNodeSpan: vi.fn(), + withSpanContext: vi.fn((_span: unknown, fn: () => unknown) => fn()), + } + }), +})) + +/** + * Returns the Tracer mock instance owned by the Swarm. + * Agents are constructed before the Swarm, so the Swarm's Tracer + * is always the last one created during Swarm construction. + */ +function getSwarmTracer(): MockTracerInstance { + return vi.mocked(Tracer).mock.results.at(-1)!.value +} + +function createHandoffAgent( + agentId: string, + handoff: { agentId?: string; message: string; context?: Record }, + description: string = `Agent ${agentId}` +): Agent { + const model = new MockMessageModel() + .addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: handoff as JSONValue, + }) + .addTurn(new TextBlock('Done')) + return new Agent({ model, printer: false, id: agentId, description }) +} + +function createHandoffAgentWithUsage( + agentId: string, + handoff: { agentId?: string; message: string; context?: Record }, + description: string = `Agent ${agentId}` +): Agent { + const model = new MockMessageModel() + .addTurn( + { + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'tool-1', + input: handoff as JSONValue, + }, + { usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } } + ) + .addTurn(new TextBlock('Done')) + return new Agent({ model, printer: false, id: agentId, description }) +} + +describe('Swarm tracer integration', () => { + let swarm: Swarm + let tracer: MockTracerInstance + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('multi-agent span lifecycle', () => { + it('starts and ends multi-agent span on successful invocation', async () => { + swarm = new Swarm({ id: 'test-swarm', nodes: [createHandoffAgent('a', { message: 'final response' })] }) + tracer = getSwarmTracer() + + await swarm.invoke('Hello') + + expect(tracer.startMultiAgentSpan.mock.calls).toEqual([ + [{ orchestratorId: 'test-swarm', orchestratorType: 'swarm', input: 'Hello' }], + ]) + expect(tracer.endMultiAgentSpan.mock.calls.length).toBe(1) + + const [span, endOpts] = tracer.endMultiAgentSpan.mock.calls[0]! + expect(span).toStrictEqual({ mock: 'multiAgentSpan' }) + expect(endOpts).toEqual({ + duration: expect.any(Number), + usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + }) + expect(endOpts.duration).toBeGreaterThanOrEqual(0) + }) + + it('passes exact usage from result to endMultiAgentSpan', async () => { + swarm = new Swarm({ id: 'test-swarm', nodes: [createHandoffAgentWithUsage('a', { message: 'final response' })] }) + tracer = getSwarmTracer() + + await swarm.invoke('Hello') + + const [, endOpts] = tracer.endMultiAgentSpan.mock.calls[0]! + expect(endOpts.usage).toStrictEqual({ inputTokens: 10, outputTokens: 5, totalTokens: 15 }) + }) + + it('ends multi-agent span with error when maxSteps exceeded', async () => { + swarm = new Swarm({ + nodes: [ + createHandoffAgent('a', { agentId: 'b', message: 'go' }), + createHandoffAgent('b', { agentId: 'a', message: 'go' }), + ], + maxSteps: 1, + }) + tracer = getSwarmTracer() + + await expect(swarm.invoke('Hello')).rejects.toThrow('swarm reached step limit') + + const [span, endOpts] = tracer.endMultiAgentSpan.mock.calls[0]! + expect(span).toStrictEqual({ mock: 'multiAgentSpan' }) + expect(endOpts).toEqual({ + duration: expect.any(Number), + error: expect.objectContaining({ + message: expect.stringContaining('swarm reached step limit'), + }), + }) + expect(endOpts.duration).toBeGreaterThanOrEqual(0) + }) + }) + + describe('node span lifecycle', () => { + it('starts and ends node span for each agent in handoff chain', async () => { + swarm = new Swarm({ + nodes: [ + createHandoffAgent('a', { agentId: 'b', message: 'go to b' }), + createHandoffAgent('b', { message: 'final response' }), + ], + }) + tracer = getSwarmTracer() + + await swarm.invoke('Hello') + + expect(tracer.startNodeSpan.mock.calls).toEqual([ + [{ nodeId: 'a', nodeType: 'agentNode' }], + [{ nodeId: 'b', nodeType: 'agentNode' }], + ]) + expect(tracer.endNodeSpan.mock.calls.length).toBe(2) + }) + + it('ends node span with COMPLETED status, duration, and zero usage on success', async () => { + swarm = new Swarm({ nodes: [createHandoffAgent('a', { message: 'final response' })] }) + tracer = getSwarmTracer() + + await swarm.invoke('Hello') + + const [span, endOpts] = tracer.endNodeSpan.mock.calls[0]! + expect(span).toStrictEqual({ mock: 'nodeSpan' }) + expect(endOpts).toEqual({ + status: Status.COMPLETED, + duration: expect.any(Number), + usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + }) + expect(endOpts.duration).toBeGreaterThanOrEqual(0) + }) + + it('passes exact usage from node result to endNodeSpan', async () => { + swarm = new Swarm({ nodes: [createHandoffAgentWithUsage('a', { message: 'final response' })] }) + tracer = getSwarmTracer() + + await swarm.invoke('Hello') + + const [, endOpts] = tracer.endNodeSpan.mock.calls[0]! + expect(endOpts.status).toBe(Status.COMPLETED) + expect(endOpts.usage).toStrictEqual({ inputTokens: 10, outputTokens: 5, totalTokens: 15 }) + }) + + it('ends node span with error when node agent throws', async () => { + const model = new MockMessageModel().addTurn(new Error('agent exploded')) + swarm = new Swarm({ nodes: [new Agent({ model, printer: false, id: 'a', description: 'Agent a' })] }) + tracer = getSwarmTracer() + + const result = await swarm.invoke('Hello') + + expect(result.status).toBe(Status.FAILED) + const [span, endOpts] = tracer.endNodeSpan.mock.calls[0]! + expect(span).toStrictEqual({ mock: 'nodeSpan' }) + expect(endOpts).toEqual({ + status: Status.FAILED, + duration: expect.any(Number), + }) + expect(endOpts.duration).toBeGreaterThanOrEqual(0) + }) + + it('ends node span with CANCELLED status and zero duration when cancelled by hook', async () => { + swarm = new Swarm({ nodes: [createHandoffAgent('a', { message: 'final response' })] }) + tracer = getSwarmTracer() + swarm.addHook(BeforeNodeCallEvent, (event) => { + event.cancel = 'cancelled by test' + }) + + await swarm.invoke('Hello') + + expect(tracer.endNodeSpan.mock.calls).toEqual([[{ mock: 'nodeSpan' }, { status: Status.CANCELLED, duration: 0 }]]) + }) + }) + + describe('null span handling', () => { + it('completes successfully when startMultiAgentSpan returns null', async () => { + swarm = new Swarm({ nodes: [createHandoffAgent('a', { message: 'final response' })] }) + tracer = getSwarmTracer() + tracer.startMultiAgentSpan.mockReturnValue(null) + + const result = await swarm.invoke('Hello') + + expect(result.status).toBe(Status.COMPLETED) + const [span] = tracer.endMultiAgentSpan.mock.calls[0]! + expect(span).toBeNull() + }) + + it('completes successfully when startNodeSpan returns null', async () => { + swarm = new Swarm({ nodes: [createHandoffAgent('a', { message: 'final response' })] }) + tracer = getSwarmTracer() + tracer.startNodeSpan.mockReturnValue(null) + + const result = await swarm.invoke('Hello') + + expect(result.status).toBe(Status.COMPLETED) + const [span] = tracer.endNodeSpan.mock.calls[0]! + expect(span).toBeNull() + }) + }) + + describe('span context propagation', () => { + it('passes node span to every withSpanContext call during node execution', async () => { + swarm = new Swarm({ nodes: [createHandoffAgent('a', { message: 'final response' })] }) + tracer = getSwarmTracer() + + await swarm.invoke('Hello') + + // First call: multiAgentSpan to create nodeSpan, then nodeSpan for node.stream() + gen.next() calls + const calls = tracer.withSpanContext.mock.calls + expect(calls.length).toBeGreaterThanOrEqual(3) + + // First call uses multiAgentSpan to create the nodeSpan + expect(calls[0]).toEqual([{ mock: 'multiAgentSpan' }, expect.any(Function)]) + + // Subsequent calls use nodeSpan for node execution + const subsequentCalls = calls.slice(1) + expect(subsequentCalls).toEqual( + expect.arrayContaining(Array(subsequentCalls.length).fill([{ mock: 'nodeSpan' }, expect.any(Function)])) + ) + }) + }) + + describe('handoff chain tracing', () => { + it('creates node spans for each agent in a multi-hop handoff', async () => { + swarm = new Swarm({ + nodes: [ + createHandoffAgent('a', { agentId: 'b', message: 'go to b' }), + createHandoffAgent('b', { agentId: 'c', message: 'go to c' }), + createHandoffAgent('c', { message: 'final response' }), + ], + }) + tracer = getSwarmTracer() + + await swarm.invoke('Hello') + + expect(tracer.startNodeSpan).toHaveBeenCalledTimes(3) + const nodeIds = tracer.startNodeSpan.mock.calls.map((call) => call[0].nodeId) + expect(nodeIds).toStrictEqual(['a', 'b', 'c']) + expect(tracer.endNodeSpan).toHaveBeenCalledTimes(3) + }) + + it('accumulates usage across handoff chain', async () => { + swarm = new Swarm({ + nodes: [ + createHandoffAgentWithUsage('a', { agentId: 'b', message: 'go to b' }), + createHandoffAgentWithUsage('b', { message: 'final response' }), + ], + }) + tracer = getSwarmTracer() + + await swarm.invoke('Hello') + + const [, endOpts] = tracer.endMultiAgentSpan.mock.calls[0]! + expect(endOpts.usage).toStrictEqual({ inputTokens: 20, outputTokens: 10, totalTokens: 30 }) + }) + }) +}) diff --git a/strands-ts/src/multiagent/edge.ts b/strands-ts/src/multiagent/edge.ts new file mode 100644 index 0000000000..df6315228e --- /dev/null +++ b/strands-ts/src/multiagent/edge.ts @@ -0,0 +1,40 @@ +import type { Node } from './nodes.js' +import type { MultiAgentState } from './state.js' + +/** + * Evaluates whether an edge should be traversed based on the current execution state. + */ +export type EdgeHandler = (state: MultiAgentState) => boolean | Promise + +/** + * Directed edge between two nodes. + */ +export class Edge { + readonly source: Node + readonly target: Node + /** Edge condition. The edge is always traversed when no handler is provided. */ + readonly handler: EdgeHandler + + constructor(data: { source: Node; target: Node; handler?: EdgeHandler }) { + this.source = data.source + this.target = data.target + this.handler = data.handler ?? ((): boolean => true) + } +} + +/** + * Options for creating an edge with an optional condition handler. + */ +export interface EdgeOptions { + source: string + target: string + handler?: EdgeHandler +} + +/** + * An edge definition accepted by orchestration constructors. + * + * Pass a `[source, target]` tuple for the simple case, or {@link EdgeOptions} + * when per-edge configuration is needed. + */ +export type EdgeDefinition = [source: string, target: string] | EdgeOptions diff --git a/strands-ts/src/multiagent/events.ts b/strands-ts/src/multiagent/events.ts new file mode 100644 index 0000000000..8f792fa4f9 --- /dev/null +++ b/strands-ts/src/multiagent/events.ts @@ -0,0 +1,338 @@ +import { HookableEvent, StreamEvent } from '../hooks/events.js' +import type { AgentStreamEvent, InvocationState } from '../types/agent.js' +import type { MultiAgentResult, MultiAgentState, NodeResult } from './state.js' +import type { MultiAgent } from './multiagent.js' +import type { NodeType } from './nodes.js' +import type { Interruptible } from '../interrupt.js' +import { interruptFromMultiAgentNode } from '../interrupt.js' +import type { InterruptParams } from '../types/interrupt.js' +import type { JSONValue } from '../types/json.js' + +/** + * Event triggered when a multi-agent orchestrator has finished initialization. + */ +export class MultiAgentInitializedEvent extends HookableEvent { + readonly type = 'multiAgentInitializedEvent' as const + readonly orchestrator: MultiAgent + + constructor(data: { orchestrator: MultiAgent }) { + super() + this.orchestrator = data.orchestrator + } + + toJSON(): Pick { + return { type: this.type } + } +} + +/** + * Event triggered before orchestrator execution starts. + */ +export class BeforeMultiAgentInvocationEvent extends HookableEvent { + readonly type = 'beforeMultiAgentInvocationEvent' as const + readonly orchestrator: MultiAgent + readonly state: MultiAgentState + readonly invocationState: InvocationState + + constructor(data: { orchestrator: MultiAgent; state: MultiAgentState; invocationState: InvocationState }) { + super() + this.orchestrator = data.orchestrator + this.state = data.state + this.invocationState = data.invocationState + } + + toJSON(): Pick { + return { type: this.type } + } +} + +/** + * Event triggered after orchestrator execution completes. + */ +export class AfterMultiAgentInvocationEvent extends HookableEvent { + readonly type = 'afterMultiAgentInvocationEvent' as const + readonly orchestrator: MultiAgent + readonly state: MultiAgentState + readonly invocationState: InvocationState + + constructor(data: { orchestrator: MultiAgent; state: MultiAgentState; invocationState: InvocationState }) { + super() + this.orchestrator = data.orchestrator + this.state = data.state + this.invocationState = data.invocationState + } + + override _shouldReverseCallbacks(): boolean { + return true + } + + toJSON(): Pick { + return { type: this.type } + } +} + +/** + * Event triggered before a node begins execution. + * Hook callbacks can set {@link cancel} to prevent the node from executing. + */ +export class BeforeNodeCallEvent extends HookableEvent implements Interruptible { + readonly type = 'beforeNodeCallEvent' as const + readonly orchestrator: MultiAgent + readonly state: MultiAgentState + readonly nodeId: string + readonly invocationState: InvocationState + + /** + * Set by hook callbacks to cancel node execution. + * When set to `true`, a default cancel message is used. + * When set to a string, that string is used as the cancel message. + */ + cancel: boolean | string = false + + constructor(data: { + orchestrator: MultiAgent + state: MultiAgentState + nodeId: string + invocationState: InvocationState + }) { + super() + this.orchestrator = data.orchestrator + this.state = data.state + this.nodeId = data.nodeId + this.invocationState = data.invocationState + } + + /** + * Raises an orchestrator-level interrupt that pauses the run before this node + * executes. If a prior resume has answered the interrupt, returns the response; + * otherwise throws an `InterruptError` and the orchestrator produces an + * INTERRUPTED result with the pending interrupt. + * + * The interrupt is stored on the target node's `NodeState.interrupts`, so resume + * via `InterruptResponseContent[]` routes through the same machinery as child- + * agent interrupts. + */ + interrupt(params: InterruptParams): T { + const nodeState = this.state.node(this.nodeId) + if (!nodeState) { + throw new Error(`node_id=<${this.nodeId}> | node state not found`) + } + return interruptFromMultiAgentNode( + nodeState.interrupts, + `multiagent-hook:beforeNodeCall:${this.nodeId}:${params.name}`, + params, + 'multiagent-hook' + ) + } + + toJSON(): Pick { + return { type: this.type, nodeId: this.nodeId } + } +} + +/** + * Event triggered after a node completes execution. + */ +export class AfterNodeCallEvent extends HookableEvent { + readonly type = 'afterNodeCallEvent' as const + readonly orchestrator: MultiAgent + readonly state: MultiAgentState + readonly nodeId: string + readonly invocationState: InvocationState + readonly error?: Error + + constructor(data: { + orchestrator: MultiAgent + state: MultiAgentState + nodeId: string + invocationState: InvocationState + error?: Error + }) { + super() + this.orchestrator = data.orchestrator + this.state = data.state + this.nodeId = data.nodeId + this.invocationState = data.invocationState + if (data.error !== undefined) { + this.error = data.error + } + } + + override _shouldReverseCallbacks(): boolean { + return true + } + + toJSON(): Pick & { error?: { message?: string } } { + return { + type: this.type, + nodeId: this.nodeId, + ...(this.error !== undefined && { error: { message: this.error.message } }), + } + } +} + +/** + * Tagged inner event from a node, discriminated by {@link source}. + * + * Use `inner.source` to determine the event origin, then `inner.event` + * to access the underlying event and switch on its `type`. + * + * Sources: + * - `'agent'` — the node wraps an {@link Agent} instance. The event is an + * {@link AgentStreamEvent} and can be narrowed via `event.type`. + * - `'multiAgent'` — the node wraps a nested orchestrator (e.g. {@link Graph} + * or {@link Swarm}). The event is a {@link MultiAgentStreamEvent} (excluding + * {@link NodeStreamUpdateEvent}, which passes through directly). + * - `'custom'` — the node wraps an {@link InvokableAgent} that is not an + * {@link Agent} instance (e.g. {@link A2AAgent} or a third-party implementation). + * The event is a {@link StreamEvent} with no further type narrowing available. + */ +export type NodeStreamUpdateInnerEvent = + | { readonly source: 'agent'; readonly event: AgentStreamEvent } + | { readonly source: 'multiAgent'; readonly event: Exclude } + | { readonly source: 'custom'; readonly event: StreamEvent } + +/** + * Wraps an inner streaming event from a node with the node's identity. + * Emitted during node execution to propagate agent-level or nested + * multi-agent events up to the orchestration layer. + */ +export class NodeStreamUpdateEvent extends HookableEvent { + readonly type = 'nodeStreamUpdateEvent' as const + readonly nodeId: string + readonly nodeType: NodeType + readonly state: MultiAgentState + readonly inner: NodeStreamUpdateInnerEvent + readonly invocationState: InvocationState + + constructor(data: { + nodeId: string + nodeType: NodeType + state: MultiAgentState + inner: NodeStreamUpdateInnerEvent + invocationState: InvocationState + }) { + super() + this.nodeId = data.nodeId + this.nodeType = data.nodeType + this.state = data.state + this.inner = data.inner + this.invocationState = data.invocationState + } + + toJSON(): Pick { + return { type: this.type, nodeId: this.nodeId, nodeType: this.nodeType, inner: this.inner } + } +} + +/** + * Event triggered when a node finishes execution. + * Wraps the {@link NodeResult} for the completed node. + */ +export class NodeResultEvent extends HookableEvent { + readonly type = 'nodeResultEvent' as const + readonly nodeId: string + readonly nodeType: NodeType + readonly state: MultiAgentState + readonly result: NodeResult + readonly invocationState: InvocationState + + constructor(data: { + nodeId: string + nodeType: NodeType + state: MultiAgentState + result: NodeResult + invocationState: InvocationState + }) { + super() + this.nodeId = data.nodeId + this.nodeType = data.nodeType + this.state = data.state + this.result = data.result + this.invocationState = data.invocationState + } + + toJSON(): Pick { + return { type: this.type, nodeId: this.nodeId, nodeType: this.nodeType, result: this.result } + } +} + +/** + * Event triggered when execution transitions between nodes. + */ +export class MultiAgentHandoffEvent extends HookableEvent { + readonly type = 'multiAgentHandoffEvent' as const + readonly source: string + readonly targets: string[] + readonly state: MultiAgentState + readonly invocationState: InvocationState + + constructor(data: { source: string; targets: string[]; state: MultiAgentState; invocationState: InvocationState }) { + super() + this.source = data.source + this.targets = data.targets + this.state = data.state + this.invocationState = data.invocationState + } + + toJSON(): Pick { + return { type: this.type, source: this.source, targets: this.targets } + } +} + +/** + * Event triggered when a node is cancelled via {@link BeforeNodeCallEvent.cancel}. + */ +export class NodeCancelEvent extends HookableEvent { + readonly type = 'nodeCancelEvent' as const + readonly nodeId: string + readonly state: MultiAgentState + readonly message: string + readonly invocationState: InvocationState + + constructor(data: { nodeId: string; state: MultiAgentState; message: string; invocationState: InvocationState }) { + super() + this.nodeId = data.nodeId + this.state = data.state + this.message = data.message + this.invocationState = data.invocationState + } + + toJSON(): Pick { + return { type: this.type, nodeId: this.nodeId, message: this.message } + } +} + +/** + * Event triggered as the final event in the multi-agent stream. + * Wraps the {@link MultiAgentResult} containing the aggregate outcome. + */ +export class MultiAgentResultEvent extends HookableEvent { + readonly type = 'multiAgentResultEvent' as const + readonly result: MultiAgentResult + readonly invocationState: InvocationState + + constructor(data: { result: MultiAgentResult; invocationState: InvocationState }) { + super() + this.result = data.result + this.invocationState = data.invocationState + } + + toJSON(): Pick { + return { type: this.type, result: this.result } + } +} + +/** + * Union of all multi-agent streaming events. + */ +export type MultiAgentStreamEvent = + | BeforeMultiAgentInvocationEvent + | AfterMultiAgentInvocationEvent + | BeforeNodeCallEvent + | AfterNodeCallEvent + | NodeStreamUpdateEvent + | NodeResultEvent + | NodeCancelEvent + | MultiAgentHandoffEvent + | MultiAgentResultEvent diff --git a/strands-ts/src/multiagent/graph.ts b/strands-ts/src/multiagent/graph.ts new file mode 100644 index 0000000000..b315ffe8f1 --- /dev/null +++ b/strands-ts/src/multiagent/graph.ts @@ -0,0 +1,851 @@ +import type { AttributeValue } from '@opentelemetry/api' +import type { InvocationState, InvokableAgent } from '../types/agent.js' +import type { MultiAgentContentInput, MultiAgentInput, MultiAgentInvokeOptions } from './multiagent.js' +import { + applyOrchestratorHookResponses, + dropStaleInterruptedResult, + extractResumeResponses, + groupInterruptResponsesByNode, + recordHookInterrupt, +} from './multiagent.js' +import type { ContentBlock } from '../types/messages.js' +import { TextBlock, contentBlockFromData } from '../types/messages.js' +import type { InterruptResponseContent } from '../types/interrupt.js' +import { InterruptError } from '../interrupt.js' +import { logger } from '../logging/logger.js' +import { warnOnce } from '../logging/warn-once.js' +import { HookableEvent } from '../hooks/events.js' +import { HookRegistryImplementation } from '../hooks/registry.js' +import type { HookCallback, HookableEventConstructor, HookCleanup } from '../hooks/types.js' +import type { MultiAgentPlugin } from './plugins.js' +import type { SessionManager } from '../session/session-manager.js' +import { MultiAgentPluginRegistry } from './plugins.js' +import type { NodeDefinition } from './nodes.js' +import { AgentNode, MultiAgentNode, Node } from './nodes.js' +import { MultiAgentState, MultiAgentResult, NodeResult, Status } from './state.js' +import type { MultiAgent } from './multiagent.js' +import { Swarm } from './swarm.js' +import type { MultiAgentStreamEvent } from './events.js' +import { + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentHandoffEvent, + MultiAgentInitializedEvent, + MultiAgentResultEvent, + NodeCancelEvent, + NodeResultEvent, +} from './events.js' +import type { EdgeDefinition } from './edge.js' +import { Edge } from './edge.js' +import { Queue } from './queue.js' +import { Tracer } from '../telemetry/tracer.js' +import type { Span } from '@opentelemetry/api' +import { normalizeError } from '../errors.js' + +/** + * Runtime configuration for graph execution. + */ +export interface GraphConfig { + /** Max nodes executing in parallel. Defaults to `Infinity` (no limit). */ + maxConcurrency?: number + /** Max total steps (prevents infinite loops in cyclic graphs). Defaults to `Infinity` (no limit). */ + maxSteps?: number + /** + * Wall-clock ceiling for the entire graph invocation, in milliseconds. Defaults to `Infinity` + * (no limit). + * + * Does not propagate into nested orchestrators wrapped via `MultiAgentNode` — a nested + * `Swarm`/`Graph` runs to completion under its own timeout config; the parent graph's + * timeout only fires once the nested node returns. + */ + timeout?: number + /** + * Fallback per-node wall-clock ceiling in milliseconds. Applied to any `AgentNode` that + * doesn't set its own `timeout`. Defaults to `Infinity` (no limit). + * + * Does not apply to `MultiAgentNode`. Set `timeout`/`nodeTimeout` on the nested + * orchestrator to bound it. + * + * Enforced via `AbortSignal` — cancellation is cooperative, so a tool that neither polls + * its cancel signal nor forwards it to a cancellable API can run past this deadline. + */ + nodeTimeout?: number +} + +/** + * Options for creating a Graph instance. + */ +export interface GraphOptions extends GraphConfig { + /** Unique identifier for this graph. Defaults to `'graph'`. */ + id?: string + /** Node definitions to construct the graph from. */ + nodes: NodeDefinition[] + /** Edge definitions describing connections between nodes. */ + edges: EdgeDefinition[] + /** Explicit source node IDs. If omitted, auto-detected from nodes with no incoming edges. */ + sources?: string[] + /** Session manager for saving and restoring graph sessions. */ + sessionManager?: SessionManager + /** Plugins for event-driven extensibility. */ + plugins?: MultiAgentPlugin[] + /** Custom trace attributes to include on all spans. */ + traceAttributes?: Record +} + +/** + * Directed graph orchestration pattern. + * + * Agents execute as nodes in a dependency graph, with edges defining execution order + * and optional conditions controlling routing. Source nodes (those with no incoming edges) + * run first, and downstream nodes execute once all their dependencies complete. Parallel + * execution is supported up to a configurable concurrency limit. + * + * Key design choices vs the Python SDK: + * - Construction uses a declarative options object rather than a mutable GraphBuilder. + * Nodes and edges are passed directly to the constructor. + * - Dependency resolution uses AND semantics: a node runs only when all incoming edges + * are satisfied. Python uses OR semantics, firing a node when any single incoming + * edge from the completed batch is satisfied. + * - Nodes are launched individually as they become ready (up to maxConcurrency). Python + * executes in discrete batches, waiting for the entire batch to complete before + * scheduling the next set of nodes. + * - Agent nodes are stateless by default (snapshot/restore on each execution). Python + * accumulates agent state across executions unless `reset_on_revisit` is enabled. + * - Node failures produce a FAILED result, allowing parallel paths to continue. + * MultiAgent-level limits (maxSteps) throw exceptions. Python does the inverse: + * node failures throw exceptions (fail-fast), while limit violations return a + * FAILED result. + * + * @example + * ```typescript + * const graph = new Graph({ + * nodes: [researcher, writer], + * edges: [['researcher', 'writer']], + * }) + * + * const result = await graph.invoke('Explain quantum computing') + * ``` + */ +export class Graph implements MultiAgent { + readonly id: string + readonly nodes: ReadonlyMap + readonly edges: readonly Edge[] + readonly config: Required + private readonly _pluginRegistry: MultiAgentPluginRegistry + private readonly _hookRegistry: HookRegistryImplementation + private readonly _sources: Node[] + private readonly _tracer: Tracer + readonly sessionManager?: SessionManager | undefined + private _initialized: boolean + /** + * State retained across invocations when a run ends INTERRUPTED. Lets + * `graph.invoke(responses)` resume on the same instance without requiring a + * SessionManager, mirroring single-agent ergonomics. Cleared when a run + * terminates in any non-INTERRUPTED state. + */ + private _pendingInterruptState?: MultiAgentState + + constructor(options: GraphOptions) { + const { id, nodes, edges, sources, sessionManager, plugins, traceAttributes, ...config } = options + + this.id = id ?? 'graph' + + this.config = { + maxConcurrency: config.maxConcurrency ?? Infinity, + maxSteps: config.maxSteps ?? Infinity, + timeout: config.timeout ?? Infinity, + nodeTimeout: config.nodeTimeout ?? Infinity, + } + this._validateConfig() + + if (this.config.maxSteps === Infinity && this.config.timeout === Infinity) { + warnOnce(logger, 'graph has no maxSteps or timeout set; execution is unbounded') + } + + this.nodes = this._resolveNodes(nodes) + this.edges = this._resolveEdges(edges) + this._sources = this._resolveSources(sources) + this._validateSources() + + this.sessionManager = sessionManager + + if (sessionManager && plugins?.some((p) => p.name === sessionManager.name)) { + throw new Error('sessionManager was provided as both a constructor argument and in the plugins array') + } + + this._hookRegistry = new HookRegistryImplementation() + this._pluginRegistry = new MultiAgentPluginRegistry([ + ...(plugins ?? []), + ...(sessionManager ? [sessionManager] : []), + ]) + this._tracer = new Tracer(traceAttributes) + this._initialized = false + } + + /** + * Initialize the graph. Invokes the {@link MultiAgentInitializedEvent} callback. + * Called automatically on first invocation. + */ + async initialize(): Promise { + if (this._initialized) return + await this._pluginRegistry.initialize(this) + await this._hookRegistry.invokeCallbacks(new MultiAgentInitializedEvent({ orchestrator: this })) + this._initialized = true + } + + /** + * Invoke graph and return final result (consumes stream). + * + * @param input - The input to pass to entry point nodes + * @param options - Optional per-invocation options (e.g., {@link InvocationState}) + * @returns Promise resolving to the final MultiAgentResult + */ + async invoke(input: MultiAgentInput, options?: MultiAgentInvokeOptions): Promise { + const gen = this.stream(input, options) + let next = await gen.next() + while (!next.done) { + next = await gen.next() + } + return next.value + } + + /** + * Register a hook callback for a specific graph event type. + * + * @param eventType - The event class constructor to register the callback for + * @param callback - The callback function to invoke when the event occurs + * @returns Cleanup function that removes the callback when invoked + */ + addHook(eventType: HookableEventConstructor, callback: HookCallback): HookCleanup { + return this._hookRegistry.addCallback(eventType, callback) + } + + /** + * Stream graph execution, yielding events as nodes execute. + * Invokes hook callbacks for each event before yielding. + * + * @param input - The input to pass to entry nodes + * @param options - Optional per-invocation options (e.g., {@link InvocationState}) + * @returns Async generator yielding streaming events and returning a MultiAgentResult + */ + async *stream( + input: MultiAgentInput, + options?: MultiAgentInvokeOptions + ): AsyncGenerator { + await this.initialize() + + // Resolve invocationState once; the same object is threaded to every node's + // child agent so mutations in one node are visible in the next. + const invocationState: InvocationState = options?.invocationState ?? {} + + // Hook invocation lives in `_stream` so hook-raised `InterruptError`s land in the + // same frame as the execution loop. + const gen = this._stream(input, invocationState, options?.cancelSignal) + try { + let next = await gen.next() + while (!next.done) { + yield next.value + next = await gen.next() + } + return next.value + } finally { + await gen.return(undefined as never) + } + } + + private async *_stream( + input: MultiAgentInput, + invocationState: InvocationState, + externalCancelSignal?: AbortSignal + ): AsyncGenerator { + // Reuse state from a prior INTERRUPTED run so `graph.invoke(responses)` can + // resume on the same instance without a SessionManager. + const state = this._pendingInterruptState ?? new MultiAgentState({ nodeIds: [...this.nodes.keys()] }) + delete this._pendingInterruptState + + const queue = new Queue() + const streams = new Map>() + + const multiAgentSpan = this._tracer.startMultiAgentSpan({ + orchestratorId: this.id, + orchestratorType: 'graph', + input, + }) + + // SessionManager (or plugins) may restore state.results here via the hook + yield* this._emit(new BeforeMultiAgentInvocationEvent({ orchestrator: this, state, invocationState })) + + // Resume input bypasses dependency resolution (routed by interrupt id). On fresh + // runs, stash the input so resume can replay it to hook-gated nodes that never ran. + // + // Example: Source node A has a `BeforeNodeCallEvent` hook that interrupts. The user + // calls `graph.invoke('original task')`, the hook fires before A executes, the run + // pauses with status INTERRUPTED. On resume, `graph.invoke([response])` only carries + // the response — A still needs `'original task'` as its input because A never ran + // and has no snapshot or upstream results to fall back on. `state._pendingInput` + // carries `'original task'` across the pause so resume can replay it. + const resumeResponses = extractResumeResponses(input) + const interruptResponsesByNode = resumeResponses ? groupInterruptResponsesByNode(resumeResponses, state) : undefined + let contentInput: MultiAgentContentInput | undefined + if (resumeResponses) { + contentInput = state._pendingInput as MultiAgentContentInput | undefined + } else { + contentInput = input as MultiAgentContentInput + state._pendingInput = contentInput + } + + const targets = interruptResponsesByNode + ? [...interruptResponsesByNode.keys()].map((id) => { + const node = this.nodes.get(id) + if (!node) { + throw new Error( + `node_id=<${id}>, graph_id=<${this.id}> | resume response targets a node missing from the graph; topology changed between save and resume?` + ) + } + return node + }) + : ((await this._findResumeTargets(state)) ?? [...this._sources]) + + // Wall-clock timeout for the whole graph invocation. External cancellation is kept + // on its own signal so the loop's abort checks below can distinguish the two causes + // and produce the right error message. + const execController = new AbortController() + const execTimeoutHandle = Number.isFinite(this.config.timeout) + ? setTimeout(() => execController.abort(), this.config.timeout) + : undefined + + const cancelSignal = externalCancelSignal + ? AbortSignal.any([execController.signal, externalCancelSignal]) + : execController.signal + + let interrupted = false + let caughtError: Error | undefined + let result: MultiAgentResult | undefined + try { + while (targets.length > 0 || streams.size > 0) { + if (execTimeoutHandle !== undefined && execController.signal.aborted) { + throw new Error(`timeout=<${this.config.timeout}>, graph_id=<${this.id}> | graph exceeded wall-clock budget`) + } + if (externalCancelSignal?.aborted) { + throw new Error(`graph_id=<${this.id}> | graph cancelled by external signal`) + } + while (!interrupted && targets.length > 0 && streams.size < this.config.maxConcurrency) { + const node = targets.shift()! + + this._checkSteps(state) + state.steps++ + + // Resolve input first so `applyOrchestratorHookResponses` has populated the + // stored `Interrupt.response` entries before the BeforeNodeCall hook reads them. + const nodeInput = this._resolveInputForScheduling( + node, + interruptResponsesByNode?.get(node.id), + contentInput, + state + ) + + const nodeSpan = this._tracer.withSpanContext(multiAgentSpan, () => + this._tracer.startNodeSpan({ nodeId: node.id, nodeType: node.type }) + ) + + const preResult = yield* this._runBeforeNodeCall(node, nodeSpan, state, invocationState) + if (preResult !== undefined) { + // Hook gated the node before it could run; surface the synthetic result + // through the queue so the main loop handles short-circuit and downstream + // scheduling uniformly with normal node results. + queue.push({ type: 'result', node, result: preResult }) + continue + } + + streams.set(node.id, this._streamNode(node, nodeInput, state, queue, nodeSpan, invocationState, cancelSignal)) + } + + await queue.wait() + while (queue.size > 0) { + const { data, ack } = queue.shift()! + + if (data.type === 'event') { + await this._hookRegistry.invokeCallbacks(data.event) + yield data.event + ack() + continue + } + + if (data.type === 'error') { + streams.delete(data.node.id) + ack() + throw data.error + } + + const { node, result: nodeResult } = data + streams.delete(node.id) + ack() + + state.results.push(nodeResult) + + if (interrupted) continue + + // Stop scheduling new nodes once any node has interrupted; in-flight siblings + // run to completion on their own. + if (nodeResult.status === Status.INTERRUPTED) { + interrupted = true + continue + } + + const ready = await this._findReady(node, state, streams, targets) + if (ready.length > 0) { + yield* this._emit( + new MultiAgentHandoffEvent({ + source: node.id, + targets: ready.map((n) => n.id), + state, + invocationState, + }) + ) + targets.push(...ready) + } + } + } + + result = new MultiAgentResult({ + results: state.results, + content: this._resolveContent(state), + duration: Date.now() - state.startTime, + }) + // Stash on interrupt so same-instance resume has state; otherwise start fresh. + if (result.status === Status.INTERRUPTED) { + this._pendingInterruptState = state + } else { + delete this._pendingInterruptState + delete state._pendingInput + } + } catch (error) { + caughtError = normalizeError(error) + throw caughtError + } finally { + if (execTimeoutHandle !== undefined) clearTimeout(execTimeoutHandle) + queue.dispose() + await Promise.allSettled(streams.values()) + + this._tracer.endMultiAgentSpan(multiAgentSpan, { + duration: Date.now() - state.startTime, + ...(result && { usage: result.usage }), + ...(caughtError && { error: caughtError }), + }) + + yield* this._emit(new AfterMultiAgentInvocationEvent({ orchestrator: this, state, invocationState })) + } + + yield* this._emit(new MultiAgentResultEvent({ result, invocationState })) + return result + } + + /** + * Invokes hook callbacks on an event, then yields it. + */ + private async *_emit(event: T): AsyncGenerator { + await this._hookRegistry.invokeCallbacks(event) + yield event + } + + /** + * Fires `BeforeNodeCallEvent` and handles hook-raised interrupts or cancels inline. + * Returns a synthetic NodeResult (INTERRUPTED or CANCELLED) when a hook gates the + * node, in which case the caller skips `_streamNode` and surfaces the result directly. + * Returns `undefined` when no hook gated the node and execution should proceed. + * + * Owns the `nodeSpan` on gated paths — `_streamNode` owns it on the ungated path. + * Yields the `NodeResultEvent` + `AfterNodeCallEvent` lifecycle pair on gated paths + * so observers see the same event sequence regardless of how the node terminated. + */ + private async *_runBeforeNodeCall( + node: Node, + nodeSpan: Span | null, + state: MultiAgentState, + invocationState: InvocationState + ): AsyncGenerator { + const nodeState = state.node(node.id)! + const beforeEvent = new BeforeNodeCallEvent({ orchestrator: this, state, nodeId: node.id, invocationState }) + try { + await this._hookRegistry.invokeCallbacks(beforeEvent) + } catch (error) { + if (error instanceof InterruptError) { + const result = recordHookInterrupt(node.id, nodeState) + yield beforeEvent + yield* this._emit(new NodeResultEvent({ nodeId: node.id, nodeType: node.type, state, result, invocationState })) + yield* this._emit(new AfterNodeCallEvent({ orchestrator: this, state, nodeId: node.id, invocationState })) + this._tracer.endNodeSpan(nodeSpan, { status: Status.INTERRUPTED, duration: result.duration }) + return result + } + throw error + } + yield beforeEvent + + if (beforeEvent.cancel) { + const message = typeof beforeEvent.cancel === 'string' ? beforeEvent.cancel : 'node cancelled by hook' + // Cancel path doesn't go through Node.stream, so do its INTERRUPTED cleanup here. + dropStaleInterruptedResult(node.id, nodeState, state) + const result = new NodeResult({ nodeId: node.id, status: Status.CANCELLED, duration: 0 }) + nodeState.status = Status.CANCELLED + nodeState.results.push(result) + yield* this._emit(new NodeCancelEvent({ nodeId: node.id, state, message, invocationState })) + yield* this._emit(new AfterNodeCallEvent({ orchestrator: this, state, nodeId: node.id, invocationState })) + this._tracer.endNodeSpan(nodeSpan, { status: Status.CANCELLED, duration: 0 }) + return result + } + + return undefined + } + + /** + * Runs a node whose `BeforeNodeCallEvent` already fired without a hook gating it + * (interrupt or cancel are handled by `_runBeforeNodeCall` before this coroutine + * is spawned). Takes ownership of the already-started `nodeSpan` and ends it. + */ + private async _streamNode( + node: Node, + input: MultiAgentInput, + state: MultiAgentState, + queue: Queue, + nodeSpan: Span | null, + invocationState: InvocationState, + executionSignal?: AbortSignal + ): Promise { + // Per-node timeout only applies to AgentNode; a nested MultiAgentNode manages + // its own node-level timeouts. + const nodeTimeout = node instanceof AgentNode ? (node.timeout ?? this.config.nodeTimeout) : Infinity + const nodeTimeoutController = Number.isFinite(nodeTimeout) ? new AbortController() : undefined + const nodeTimeoutHandle = nodeTimeoutController + ? setTimeout(() => nodeTimeoutController.abort(), nodeTimeout) + : undefined + const signals = [executionSignal, nodeTimeoutController?.signal].filter((s): s is AbortSignal => s !== undefined) + const cancelSignal = signals.length > 0 ? AbortSignal.any(signals) : undefined + + try { + const gen = this._tracer.withSpanContext(nodeSpan, () => + node.stream(input, state, { invocationState, ...(cancelSignal && { cancelSignal }) }) + ) + let next = await this._tracer.withSpanContext(nodeSpan, () => gen.next()) + while (!next.done) { + await queue.send({ type: 'event', node, event: next.value }) + next = await this._tracer.withSpanContext(nodeSpan, () => gen.next()) + } + + if (nodeTimeoutController?.signal.aborted) { + throw new Error( + `node_timeout=<${nodeTimeout}>, node_id=<${node.id}>, graph_id=<${this.id}> | node exceeded wall-clock budget` + ) + } + + const result = next.value + this._tracer.endNodeSpan(nodeSpan, { status: result.status, duration: result.duration, usage: result.usage }) + queue.push({ type: 'result', node, result }) + + await queue.send({ + type: 'event', + node, + event: new AfterNodeCallEvent({ orchestrator: this, state, nodeId: node.id, invocationState }), + }) + } catch (error) { + const nodeError = normalizeError(error) + this._tracer.endNodeSpan(nodeSpan, { error: nodeError }) + + await queue.send({ + type: 'event', + node, + event: new AfterNodeCallEvent({ + orchestrator: this, + state, + nodeId: node.id, + invocationState, + error: nodeError, + }), + }) + queue.push({ + type: 'error', + node, + error: nodeError, + }) + } finally { + if (nodeTimeoutHandle !== undefined) clearTimeout(nodeTimeoutHandle) + } + } + + private _validateConfig(): void { + if (this.config.maxConcurrency < 1) { + throw new Error(`max_concurrency=<${this.config.maxConcurrency}> | must be at least 1`) + } + if (this.config.maxSteps < 1) { + throw new Error(`max_steps=<${this.config.maxSteps}> | must be at least 1`) + } + if (this.config.timeout < 1) { + throw new Error(`timeout=<${this.config.timeout}> | must be at least 1`) + } + if (this.config.nodeTimeout < 1) { + throw new Error(`node_timeout=<${this.config.nodeTimeout}> | must be at least 1`) + } + } + + private _validateSources(): void { + if (this._sources.length === 0) { + throw new Error('graph has no source nodes') + } + + const visited = new Set() + const adjacency = new Map() + for (const edge of this.edges) { + const targets = adjacency.get(edge.source.id) ?? [] + targets.push(edge.target.id) + adjacency.set(edge.source.id, targets) + } + + const queue = this._sources.map((n) => n.id) + while (queue.length > 0) { + const id = queue.shift()! + if (visited.has(id)) continue + visited.add(id) + for (const target of adjacency.get(id) ?? []) { + queue.push(target) + } + } + + for (const id of this.nodes.keys()) { + if (!visited.has(id)) { + throw new Error(`node_id=<${id}> | unreachable from any source node`) + } + } + } + + private _resolveNodes(definitions: NodeDefinition[]): Map { + const nodes = new Map() + + for (const definition of definitions) { + let node: Node + + if (definition instanceof Node) { + node = definition + } else if ('orchestrator' in definition) { + node = new MultiAgentNode(definition) + } else if ('agent' in definition) { + node = new AgentNode(definition) + } else if (definition instanceof Graph || definition instanceof Swarm) { + node = new MultiAgentNode({ orchestrator: definition }) + } else { + node = new AgentNode({ agent: definition as InvokableAgent }) + } + + if (nodes.has(node.id)) { + throw new Error(`node_id=<${node.id}> | duplicate node id`) + } + nodes.set(node.id, node) + } + + return nodes + } + + private _resolveEdges(definitions: EdgeDefinition[]): Edge[] { + const edges: Edge[] = [] + for (const definition of definitions) { + const [sourceId, targetId, handler] = Array.isArray(definition) + ? [definition[0], definition[1], undefined] + : [definition.source, definition.target, definition.handler] + + const source = this.nodes.get(sourceId) + const target = this.nodes.get(targetId) + if (!source) { + throw new Error(`source=<${sourceId}> | edge references unknown source node`) + } + if (!target) { + throw new Error(`target=<${targetId}> | edge references unknown target node`) + } + edges.push(new Edge({ source, target, ...(handler && { handler }) })) + } + return edges + } + + private _resolveSources(sourceIds?: string[]): Node[] { + if (sourceIds) { + const sources: Node[] = [] + for (const id of sourceIds) { + const node = this.nodes.get(id) + if (!node) { + throw new Error(`source=<${id}> | source references unknown node`) + } + sources.push(node) + } + return sources + } + + const targetIds = new Set(this.edges.map((e) => e.target.id)) + return [...this.nodes.values()].filter((node) => !targetIds.has(node.id)) + } + + /** + * Identifies terminus nodes and returns their combined content. + * A terminus node is where an execution path ended: completed with no + * downstream progress, or failed/cancelled. + */ + private _resolveContent(state: MultiAgentState): ContentBlock[] { + for (const [id, ns] of state.nodes.entries()) { + if (ns.status === Status.FAILED || ns.status === Status.CANCELLED) { + ns.terminus = true + } else if (ns.status === Status.COMPLETED) { + ns.terminus = !this.edges + .filter((e) => e.source.id === id) + .some((e) => state.node(e.target.id)?.status !== Status.PENDING) + } + } + return [...state.nodes.values()].filter((ns) => ns.terminus).flatMap((ns) => ns.content) + } + + /** + * Chooses the input for a node about to be scheduled, handling the three resume cases: + * routed orchestrator-hook responses (forward leftovers to the agent), routed responses + * fully consumed by the hook (replay the original invocation input), and fresh runs + * (dependency-merged). Falls back to an empty input with a warning if a custom + * SessionManager dropped `_pendingInput`. + */ + private _resolveInputForScheduling( + node: Node, + routed: InterruptResponseContent[] | undefined, + contentInput: MultiAgentContentInput | undefined, + state: MultiAgentState + ): MultiAgentInput { + if (routed) { + const nodeState = state.node(node.id) + if (!nodeState) { + throw new Error( + `node_id=<${node.id}>, graph_id=<${this.id}> | routed interrupt response targets a node missing from state; topology changed between save and resume?` + ) + } + const forwarded = applyOrchestratorHookResponses(nodeState, routed) + if (forwarded.length > 0) return forwarded + } + if (contentInput === undefined) { + logger.warn(`node_id=<${node.id}>, graph_id=<${this.id}> | no pending input on resume; using empty`) + return this._resolveNodeInput(node, '', state) + } + return this._resolveNodeInput(node, contentInput, state) + } + + /** + * Builds the input for a node by combining the original task with dependency outputs. + * + * Only called for non-resume executions: the caller routes resume responses directly + * to interrupted nodes without going through dependency resolution, so this helper + * never sees `InterruptResponseContent[]`. + */ + private _resolveNodeInput(node: Node, input: MultiAgentContentInput, state: MultiAgentState): MultiAgentInput { + const deps: ContentBlock[] = [] + for (const edge of this.edges.filter((e) => e.target.id === node.id)) { + const ns = state.node(edge.source.id)! + if (ns.content.length > 0) { + deps.push(new TextBlock(`[node: ${edge.source.id}]`), ...ns.content) + } + } + + if (deps.length === 0) return input + + const blocks = + typeof input === 'string' + ? [new TextBlock(input)] + : input.map((b) => ('type' in b ? (b as ContentBlock) : contentBlockFromData(b))) + return [...blocks, ...deps] + } + + /** + * Finds nodes that should execute on resume from a restored {@link MultiAgentState}. + * + * Any node that did not complete is a candidate for re-execution, provided its + * dependencies are all COMPLETED and edge conditions are satisfied. This covers: + * - PENDING nodes that never started + * - EXECUTING/FAILED/CANCELLED nodes from the previous run + * - Source nodes (no incoming edges) that are not COMPLETED + * + * Works for all node types including {@link AgentNode} and {@link MultiAgentNode} + * (subgraphs/swarms). A `MultiAgentNode` that didn't complete will be re-executed + * from scratch — its inner orchestrator manages its own state independently. + * + * @returns Array of ready nodes, or `undefined` if state was not restored (fresh start) + */ + private async _findResumeTargets(state: MultiAgentState): Promise { + // No completed nodes in state means fresh start (state was not restored) + const hasCompletedNodes = [...state.nodes.values()].some((ns) => ns.status === Status.COMPLETED) + if (!hasCompletedNodes) return undefined + + const ready: Node[] = [] + for (const [id, node] of this.nodes) { + if (state.node(id)?.status === Status.COMPLETED) continue + + const incoming = this.edges.filter((e) => e.target.id === id) + if (incoming.length === 0) { + // Source node that hasn't completed + ready.push(node) + } else if (await this._allDependenciesSatisfied(incoming, state)) { + ready.push(node) + } + } + + if (ready.length > 0) { + logger.debug( + `resume_targets=<${ready.map((n) => n.id).join(', ')}>, prior_steps=<${state.steps}> | resuming graph from restored state` + ) + return ready + } + + logger.debug('all nodes completed in restored state | starting fresh') + return undefined + } + + /** + * Checks whether all incoming edges have completed sources with satisfied conditions. + */ + private async _allDependenciesSatisfied(incoming: Edge[], state: MultiAgentState): Promise { + for (const edge of incoming) { + if (state.node(edge.source.id)?.status !== Status.COMPLETED) return false + if (!(await edge.handler(state))) return false + } + return true + } + + private _checkSteps(state: MultiAgentState): void { + if (state.steps >= this.config.maxSteps) { + throw new Error(`steps=<${state.steps}> | max steps reached`) + } + } + + /** + * Finds downstream nodes that are ready to execute after a node completes. + * A target is ready when all its incoming edge sources are COMPLETED and all edge handlers return true. + * + * @param node - The node that just completed execution. + * @param state - Current multi-agent execution state. + * @param streams - Map of node IDs to their in-flight execution promises. + * @param targets - Nodes already queued for execution. + * @returns Nodes that are ready to execute. + */ + private async _findReady( + node: Node, + state: MultiAgentState, + streams: ReadonlyMap>, + targets: readonly Node[] + ): Promise { + if (state.node(node.id)?.status !== Status.COMPLETED) return [] + + const ready: Node[] = [] + + for (const edge of this.edges.filter((e) => e.source.id === node.id)) { + // skip if the target is already running or queued + if (streams.has(edge.target.id) || targets.some((n) => n.id === edge.target.id)) continue + + const incoming = this.edges.filter((e) => e.target.id === edge.target.id) + if (await this._allDependenciesSatisfied(incoming, state)) { + ready.push(edge.target) + } + } + + return ready + } +} diff --git a/strands-ts/src/multiagent/index.ts b/strands-ts/src/multiagent/index.ts new file mode 100644 index 0000000000..2075cec649 --- /dev/null +++ b/strands-ts/src/multiagent/index.ts @@ -0,0 +1,43 @@ +/** + * Multi-agent orchestration module. + */ + +export { MultiAgentState, NodeState, Status, NodeResult, MultiAgentResult } from './state.js' +export type { NodeResultUpdate, ResultStatus } from './state.js' + +export { Node, AgentNode, MultiAgentNode } from './nodes.js' +export type { + NodeConfig, + NodeInputOptions, + AgentNodeOptions, + MultiAgentNodeOptions, + NodeDefinition, + NodeType, +} from './nodes.js' + +export { + MultiAgentInitializedEvent, + BeforeMultiAgentInvocationEvent, + AfterMultiAgentInvocationEvent, + BeforeNodeCallEvent, + AfterNodeCallEvent, + NodeStreamUpdateEvent, + NodeResultEvent, + NodeCancelEvent, + MultiAgentHandoffEvent, + MultiAgentResultEvent, +} from './events.js' +export type { MultiAgentStreamEvent, NodeStreamUpdateInnerEvent } from './events.js' + +export { Edge } from './edge.js' +export type { EdgeHandler, EdgeDefinition } from './edge.js' + +export { Graph } from './graph.js' +export type { GraphConfig, GraphOptions } from './graph.js' + +export { Swarm } from './swarm.js' +export type { SwarmConfig, SwarmNodeDefinition, SwarmOptions } from './swarm.js' + +export type { MultiAgentPlugin } from './plugins.js' + +export type { MultiAgent, MultiAgentInput, MultiAgentInvokeOptions } from './multiagent.js' diff --git a/strands-ts/src/multiagent/multiagent.ts b/strands-ts/src/multiagent/multiagent.ts new file mode 100644 index 0000000000..d280a1efdd --- /dev/null +++ b/strands-ts/src/multiagent/multiagent.ts @@ -0,0 +1,216 @@ +import type { InvocationState, InvokeArgs } from '../types/agent.js' +import type { Message, MessageData } from '../types/messages.js' +import type { HookableEvent } from '../hooks/events.js' +import type { HookCallback, HookableEventConstructor, HookCleanup } from '../hooks/types.js' +import type { InterruptResponseContentData } from '../types/interrupt.js' +import { InterruptResponseContent, isInterruptResponseContent } from '../types/interrupt.js' +import type { MultiAgentStreamEvent } from './events.js' +import { NodeResult, Status } from './state.js' +import type { MultiAgentResult, MultiAgentState, NodeState } from './state.js' + +/** + * Input type for multi-agent orchestrators. Excludes `Message[]` / `MessageData[]` + * since orchestrators route content blocks between nodes rather than replaying raw + * conversation history. + * + * Accepts `InterruptResponseContent[]` / `InterruptResponseContentData[]` for resuming + * from an interrupted run — orchestrators detect resume input at the entry point and + * route responses to the interrupted nodes rather than flowing through dependency + * resolution. + */ +export type MultiAgentInput = Exclude + +/** + * The non-resume subset of {@link MultiAgentInput}. Internal orchestrator helpers that + * participate in dependency resolution / handoff routing accept this narrower type so + * they don't need to re-check for {@link InterruptResponseContent} entries at each call. + * + * @internal + */ +export type MultiAgentContentInput = Exclude< + MultiAgentInput, + InterruptResponseContent[] | InterruptResponseContentData[] +> + +/** + * Options for a single multi-agent orchestrator invocation. + */ +export interface MultiAgentInvokeOptions { + /** + * Per-invocation state forwarded to every node's child agent. Mutable — + * one node's hooks/tools can read state written by a previous node. See + * {@link InvocationState} for details. Defaults to `{}` when omitted. + */ + invocationState?: InvocationState + + /** + * Cancellation signal forwarded to every node (and into any nested orchestrators + * via `MultiAgentNode`). Composed with the orchestrator's own timeout and + * short-circuit signals, matching {@link InvokeOptions.cancelSignal} on the + * single-agent path. Cooperative — honored by nodes that forward it to their + * underlying agents/tools. + * + * When this signal aborts, the orchestrator throws rather than returning a clean + * result. This matches single-agent behavior: external cancellation is treated as + * an exceptional exit, not a normal terminal state. + */ + cancelSignal?: AbortSignal +} + +/** + * Interface for any multi-agent orchestrator that can stream execution. + * Implement this interface to create custom orchestration patterns that can be + * composed as nodes within other orchestrators via {@link MultiAgentNode}. + */ +export interface MultiAgent { + /** Unique identifier for this orchestrator. */ + readonly id: string + + /** + * Execute the orchestrator and return the final result. + * @param input - Input to pass to the orchestrator + * @param options - Optional per-invocation options + * @returns The aggregate result from all executed nodes + */ + invoke(input: MultiAgentInput, options?: MultiAgentInvokeOptions): Promise + + /** + * Execute the orchestrator and stream events as they occur. + * @param input - Input to pass to the orchestrator + * @param options - Optional per-invocation options + * @returns Async generator yielding events and returning the final result + */ + stream( + input: MultiAgentInput, + options?: MultiAgentInvokeOptions + ): AsyncGenerator + + /** + * Register a hook callback for a specific orchestrator event type. + * + * @param eventType - The event class constructor to register the callback for + * @param callback - The callback function to invoke when the event occurs + * @returns Cleanup function that removes the callback when invoked + */ + addHook(eventType: HookableEventConstructor, callback: HookCallback): HookCleanup +} + +/** + * Detects whether a {@link MultiAgentInput} is a resume from an interrupted run and + * normalizes its entries to {@link InterruptResponseContent} instances. + * + * Returns `undefined` for fresh input (string / content blocks / empty array). + */ +export function extractResumeResponses(input: MultiAgentInput): InterruptResponseContent[] | undefined { + if (!Array.isArray(input) || input.length === 0) return undefined + if (!isInterruptResponseContent(input[0])) return undefined + + const responses: InterruptResponseContent[] = [] + for (const entry of input) { + if (entry instanceof InterruptResponseContent) { + responses.push(entry) + } else if (isInterruptResponseContent(entry)) { + responses.push(InterruptResponseContent.fromJSON(entry as InterruptResponseContentData)) + } else { + throw new TypeError('Must resume from interrupt with a list of interruptResponse content blocks only') + } + } + return responses +} + +/** + * Groups a flat list of interrupt responses by the node that raised each interrupt. + * + * For each response, finds the node whose `NodeState.interrupts` contains an entry + * with a matching id. Ids are globally unique (derived from model-assigned + * `toolUseId`s) so each response maps to exactly one node. Nested orchestrators + * carry their subtree's interrupts on the wrapping `MultiAgentNode`'s state, so a + * matching response is forwarded as-is to the nested orchestrator, which does its + * own grouping recursively. + * + * @throws Error if any response's interrupt id does not match any tracked node + */ +export function groupInterruptResponsesByNode( + responses: InterruptResponseContent[], + state: MultiAgentState +): Map { + const grouped = new Map() + for (const response of responses) { + const id = response.interruptResponse.interruptId + let target: string | undefined + for (const [nodeId, nodeState] of state.nodes) { + if (nodeState.interrupts.some((i) => i.id === id)) { + target = nodeId + break + } + } + if (!target) { + throw new Error(`interrupt_id=<${id}> | no node found with matching interrupt`) + } + const bucket = grouped.get(target) ?? [] + bucket.push(response) + grouped.set(target, bucket) + } + return grouped +} + +/** + * Removes a stale INTERRUPTED result for the given node from both per-node history + * and the orchestrator-level aggregate so a fresh result (from resume or cancel) + * replaces it cleanly. No-op if the node isn't in an INTERRUPTED state. + */ +export function dropStaleInterruptedResult(nodeId: string, nodeState: NodeState, state: MultiAgentState): void { + if (nodeState.status !== Status.INTERRUPTED) return + if (nodeState.results[nodeState.results.length - 1]?.status === Status.INTERRUPTED) { + nodeState.results.pop() + } + const idx = state.results.findIndex((r) => r.nodeId === nodeId && r.status === Status.INTERRUPTED) + if (idx >= 0) state.results.splice(idx, 1) +} + +/** + * Records a hook-raised interrupt on a node that hadn't started executing: builds + * the INTERRUPTED {@link NodeResult}, transitions `nodeState.status`, and appends + * the result to `nodeState.results`. Returns the result so callers can route it + * into their own queue/lifecycle machinery. + * + * Shared between Graph and Swarm so their hook-interrupt branches don't drift. + */ +export function recordHookInterrupt(nodeId: string, nodeState: NodeState): NodeResult { + const result = new NodeResult({ + nodeId, + status: Status.INTERRUPTED, + duration: Date.now() - nodeState.startTime, + interrupts: nodeState.interrupts, + }) + nodeState.status = Status.INTERRUPTED + nodeState.results.push(result) + return result +} + +/** + * Applies interrupt responses to a node's own orchestrator-level interrupts and + * returns the remaining responses — those bound for the child agent's interrupts. + * + * Orchestrator hooks (source `'multiagent-hook'`) store their interrupts on + * `NodeState.interrupts` directly; the hook re-runs on resume and reads the stored + * response. Agent-level interrupts aren't answerable here — they flow to the child + * agent as resume input and are applied by the agent's own interrupt machinery. + */ +export function applyOrchestratorHookResponses( + nodeState: NodeState, + responses: InterruptResponseContent[] +): InterruptResponseContent[] { + const forwarded: InterruptResponseContent[] = [] + for (const response of responses) { + const local = nodeState.interrupts.find( + (i) => i.id === response.interruptResponse.interruptId && i.source === 'multiagent-hook' + ) + if (local) { + local.response = response.interruptResponse.response + } else { + forwarded.push(response) + } + } + return forwarded +} diff --git a/strands-ts/src/multiagent/nodes.ts b/strands-ts/src/multiagent/nodes.ts new file mode 100644 index 0000000000..a512d0276e --- /dev/null +++ b/strands-ts/src/multiagent/nodes.ts @@ -0,0 +1,410 @@ +import { Agent } from '../agent/agent.js' +import type { InvocationState, InvokeOptions, InvokableAgent, AgentStreamEvent } from '../types/agent.js' +import type { MultiAgentInput } from './multiagent.js' +import { dropStaleInterruptedResult } from './multiagent.js' +import type { MultiAgentStreamEvent } from './events.js' +import { NodeStreamUpdateEvent, NodeResultEvent } from './events.js' +import { NodeResult, Status } from './state.js' +import type { MultiAgentState, NodeResultUpdate } from './state.js' +import type { MultiAgent } from './multiagent.js' +import { logger } from '../logging/logger.js' +import type { z } from 'zod' +import { normalizeError } from '../errors.js' +import { omitUndefined } from '../types/json.js' + +/** + * Known node type identifiers with extensibility for custom nodes. + */ +export type NodeType = 'agentNode' | 'multiAgentNode' | (string & {}) + +/** + * Configuration for a node execution. + */ +export interface NodeConfig { + /** + * Optional description of what this node does. + */ + description?: string +} + +/** + * Per-invocation options passed from the orchestrator to a node. + */ +export interface NodeInputOptions { + /** + * Structured output schema for this node invocation. + */ + structuredOutputSchema?: z.ZodSchema + + /** + * Per-invocation state forwarded to the node's underlying agent. See + * {@link InvocationState}. Shared by reference across all nodes so one node's + * hooks/tools can read state written by a previous node. + */ + invocationState?: InvocationState + + /** + * Cancellation signal forwarded to the node's underlying agent. Used by + * orchestrators to enforce per-node timeouts or propagate external cancellation. + */ + cancelSignal?: AbortSignal +} + +/** + * Abstract base class for all multi-agent orchestration nodes. + * + * Uses the template method pattern: {@link stream} handles orchestration + * boilerplate (duration measurement, status tracking, error capture) and + * delegates to {@link handle} for node-specific execution logic. + */ +export abstract class Node { + readonly type: string = 'node' + /** Unique identifier for this node within the orchestration. */ + readonly id: string + /** Per-node configuration. */ + readonly config: NodeConfig + + /** + * @param id - Unique identifier for this node within the orchestration + * @param config - Per-node configuration + */ + constructor(id: string, config: NodeConfig) { + this.id = id + this.config = config + } + + /** + * Execute the node. Handles duration measurement, error capture, + * and delegates to handle() for node-specific logic. + * + * @param input - Input to pass to the node (string or content blocks) + * @param state - The current multi-agent state + * @param options - Per-invocation options from the orchestrator + * @returns Async generator yielding streaming events and returning a NodeResult + */ + async *stream( + input: MultiAgentInput, + state: MultiAgentState, + options?: NodeInputOptions + ): AsyncGenerator { + const nodeState = state.node(this.id)! + + // Resuming from INTERRUPTED: drop the stale result so the fresh one replaces it. + dropStaleInterruptedResult(this.id, nodeState, state) + + nodeState.status = Status.EXECUTING + nodeState.startTime = Date.now() + + // Resolve invocationState once — the same reference is threaded into handle() + // and into NodeResultEvent so callbacks see one object for the whole node run. + const invocationState: InvocationState = options?.invocationState ?? {} + const resolvedOptions: NodeInputOptions = { ...options, invocationState } + + let result: NodeResult + try { + const update = yield* this.handle(input, state, resolvedOptions) + const defaultStatus = update.interrupts && update.interrupts.length > 0 ? Status.INTERRUPTED : Status.COMPLETED + result = new NodeResult({ + nodeId: this.id, + status: defaultStatus, + duration: Date.now() - nodeState.startTime, + content: [], + ...update, + }) + } catch (error) { + // Orchestrator cancellation (short-circuit or external) maps thrown errors to + // CANCELLED — node was stopped, not broken. + const status = options?.cancelSignal?.aborted ? Status.CANCELLED : Status.FAILED + result = new NodeResult({ + nodeId: this.id, + status, + duration: Date.now() - nodeState.startTime, + error: normalizeError(error), + }) + if (status === Status.FAILED) { + logger.warn(`node_id=<${this.id}>, error=<${result.error?.message}> | node execution failed`) + } + } finally { + nodeState.status = result!.status + nodeState.results.push(result!) + nodeState.interrupts = result!.interrupts ?? [] + // Clear the stored snapshot on non-INTERRUPTED terminal states; `handle()` + // repopulates it above if this run itself interrupted. + if (result!.status !== Status.INTERRUPTED) { + delete nodeState.interruptedSnapshot + } + } + + yield new NodeResultEvent({ + nodeId: this.id, + nodeType: this.type, + state, + result, + invocationState, + }) + return result + } + + /** + * Node-specific execution logic implemented by subclasses. + * + * @param input - Input to process (string or content blocks) + * @param state - The current multi-agent state + * @param options - Per-invocation options from the orchestrator + * @returns Async generator yielding streaming events and returning a partial result + */ + abstract handle( + input: MultiAgentInput, + state: MultiAgentState, + options?: NodeInputOptions + ): AsyncGenerator +} + +/** + * Options for creating an {@link AgentNode}. + */ +export interface AgentNodeOptions { + /** The agent to wrap as a node. */ + agent: InvokableAgent + /** + * Per-node wall-clock ceiling in milliseconds. Overrides the orchestrator's + * default node timeout. Cancellation is cooperative — a tool that neither + * polls its cancel signal nor forwards it to a cancellable API can run past + * this deadline. + */ + timeout?: number + /** + * When `true`, the wrapped agent accumulates state (messages, appState, + * modelState) across node executions. Useful for graph patterns where a + * node is revisited and should build on its previous work (e.g., an + * analyst that accumulates findings, or iterative refinement). + * + * When `false` (default), the agent's state is snapshotted before each + * execution and restored in `finally`, so the node is stateless across + * visits. + * + * Throws at construction time when set to `true` with a non-`Agent` + * `InvokableAgent`, since snapshot/restore only applies to `Agent` instances. + */ + preserveContext?: boolean +} + +/** + * Node that wraps an {@link InvokableAgent} instance for multi-agent orchestration. + * + * By default, when the wrapped agent is an {@link Agent} instance, its internal + * state is snapshot/restored around each execution so it remains unchanged + * after the node completes. Pass `preserveContext: true` to opt out and let the + * wrapped agent accumulate state across node executions. + */ +export class AgentNode extends Node { + readonly type = 'agentNode' as const + private readonly _agent: InvokableAgent + /** + * Per-node wall-clock ceiling in milliseconds. When set, overrides the orchestrator's + * `nodeTimeout` for this node. Undefined means "fall back to the orchestrator's setting." + * See {@link AgentNodeOptions.timeout}. + */ + readonly timeout?: number + /** + * Whether the wrapped agent retains state across node executions. + * See {@link AgentNodeOptions.preserveContext}. + */ + readonly preserveContext: boolean + + constructor(options: AgentNodeOptions) { + const { agent, timeout, preserveContext, ...config } = options + + super(agent.id, { + ...config, + ...(agent.description !== undefined && { description: agent.description }), + }) + + this._agent = agent + if (timeout !== undefined) { + if (timeout < 1) { + throw new Error(`timeout=<${timeout}>, node_id=<${agent.id}> | must be at least 1`) + } + this.timeout = timeout + } + if (preserveContext && !(agent instanceof Agent)) { + throw new Error( + `node_id=<${agent.id}> | preserveContext=true requires an Agent instance; non-Agent InvokableAgents cannot be snapshotted` + ) + } + this.preserveContext = preserveContext ?? false + } + + get agent(): InvokableAgent { + return this._agent + } + + /** + * Executes the wrapped agent, yielding each agent streaming event + * wrapped in a {@link NodeStreamUpdateEvent}. + * + * @param input - Input to pass to the agent + * @param state - The current multi-agent state + * @param options - Per-invocation options from the orchestrator + * @returns Async generator yielding streaming events and returning the agent's content blocks + */ + async *handle( + input: MultiAgentInput, + state: MultiAgentState, + options?: NodeInputOptions + ): AsyncGenerator { + // Resolve once per handle() call — Node.stream() normally supplies this; + // handle() is public API, so direct callers get per-call state. + const invocationState: InvocationState = options?.invocationState ?? {} + + // Only Agent instances support snapshot/restore for state isolation. + // When `preserveContext` is set, skip the snapshot/restore cycle so the agent + // accumulates state across node executions. + const isAgent = this._agent instanceof Agent + const preRunSnapshot = + !this.preserveContext && isAgent ? this._agent.takeSnapshot({ preset: 'session' }) : undefined + + // Rehydrate agent state from a prior INTERRUPTED run (messages + interrupt state). + // Independent of `preserveContext`: a paused run always resumes from where it left off. + const nodeState = state.node(this.id) + if (isAgent && nodeState?.interruptedSnapshot) { + this._agent.loadSnapshot(nodeState.interruptedSnapshot) + } + + try { + const invokeOptions: InvokeOptions = { + ...(options?.structuredOutputSchema && { structuredOutputSchema: options.structuredOutputSchema }), + ...(options?.cancelSignal && { cancelSignal: options.cancelSignal }), + invocationState, + } + + const gen = this._agent.stream(input, invokeOptions) + let next = await gen.next() + while (!next.done) { + yield new NodeStreamUpdateEvent({ + nodeId: this.id, + nodeType: this.type, + state, + inner: isAgent + ? { source: 'agent', event: next.value as AgentStreamEvent } + : { source: 'custom', event: next.value }, + invocationState, + }) + next = await gen.next() + } + + const agentResult = next.value + const interrupted = + agentResult.stopReason === 'interrupt' && agentResult.interrupts && agentResult.interrupts.length > 0 + + // Capture post-interrupt state for the next resume cycle. Only Agent instances + // are snapshottable. + if (interrupted && isAgent && nodeState) { + nodeState.interruptedSnapshot = this._agent.takeSnapshot({ preset: 'session' }) + } + + return omitUndefined({ + content: agentResult.lastMessage.content, + structuredOutput: 'structuredOutput' in agentResult ? agentResult.structuredOutput : undefined, + usage: agentResult.metrics?.accumulatedUsage, + interrupts: interrupted ? agentResult.interrupts : undefined, + }) + } finally { + // Restore pre-run state — keeps the agent observably unchanged across runs. + if (preRunSnapshot) { + ;(this._agent as Agent).loadSnapshot(preRunSnapshot) + } + } + } +} + +/** + * Options for creating a {@link MultiAgentNode}. + */ +export interface MultiAgentNodeOptions extends NodeConfig { + /** The orchestrator to wrap as a node. */ + orchestrator: MultiAgent +} + +/** + * Node that wraps a multi-agent orchestrator (e.g. Graph) for nested composition. + * + * Inner {@link NodeStreamUpdateEvent}s pass through to preserve the original + * node's identity. All other events are wrapped in a new {@link NodeStreamUpdateEvent} + * tagged with this node's identity. + */ +export class MultiAgentNode extends Node { + readonly type = 'multiAgentNode' as const + private readonly _orchestrator: MultiAgent + + constructor(options: MultiAgentNodeOptions) { + const { orchestrator, ...config } = options + super(orchestrator.id, config) + this._orchestrator = orchestrator + } + + get orchestrator(): MultiAgent { + return this._orchestrator + } + + /** + * Executes the wrapped orchestrator. Inner {@link NodeStreamUpdateEvent}s + * pass through as-is; all other events are wrapped in a new + * {@link NodeStreamUpdateEvent} tagged with this node's identity. + * + * @param input - Input to pass to the orchestrator + * @param state - The current multi-agent state + * @param options - Per-invocation options. `invocationState` is forwarded to the + * nested orchestrator; `structuredOutputSchema` is not applicable here. + * @returns Async generator yielding streaming events and returning the orchestrator's content + */ + async *handle( + input: MultiAgentInput, + state: MultiAgentState, + options?: NodeInputOptions + ): AsyncGenerator { + // Resolve once per handle() call — Node.stream() normally supplies this; + // handle() is public API, so direct callers get per-call state. + const invocationState: InvocationState = options?.invocationState ?? {} + + const gen = this._orchestrator.stream(input, { + invocationState, + ...(options?.cancelSignal && { cancelSignal: options.cancelSignal }), + }) + let next = await gen.next() + while (!next.done) { + const event = next.value + if (event.type === 'nodeStreamUpdateEvent') { + yield event + } else { + yield new NodeStreamUpdateEvent({ + nodeId: this.id, + nodeType: this.type, + state, + inner: { source: 'multiAgent', event }, + invocationState, + }) + } + next = await gen.next() + } + const innerResult = next.value + const interrupted = innerResult.interrupts && innerResult.interrupts.length > 0 + + return omitUndefined({ + content: innerResult.content, + usage: innerResult.usage, + status: innerResult.status !== Status.COMPLETED ? innerResult.status : undefined, + error: innerResult.error, + interrupts: interrupted ? innerResult.interrupts : undefined, + }) + } +} + +/** + * A node definition accepted by orchestration constructors. + * + * Pass an {@link InvokableAgent} or {@link MultiAgent} directly for the simple case, + * use typed options objects for per-node configuration, or provide pre-built + * {@link Node} instances for full control. + */ +export type NodeDefinition = InvokableAgent | MultiAgent | Node | AgentNodeOptions | MultiAgentNodeOptions diff --git a/strands-ts/src/multiagent/plugins.ts b/strands-ts/src/multiagent/plugins.ts new file mode 100644 index 0000000000..d127072d02 --- /dev/null +++ b/strands-ts/src/multiagent/plugins.ts @@ -0,0 +1,90 @@ +/** + * Plugin interface and registry for extending multi-agent orchestrator functionality. + * + * This module defines the MultiAgentPlugin interface and MultiAgentPluginRegistry, + * which provide a composable way to add behavior to multi-agent orchestrators (e.g. Swarm, Graph) + * through hook registration and custom initialization. + */ + +import type { MultiAgent } from './multiagent.js' + +/** + * Interface for objects that implement multi-agent orchestrator plugin functionality. + * + * MultiAgentPlugins provide a composable way to add behavior to orchestrators + * by registering hook callbacks in their `initMultiAgent` method. + * + * @example + * ```typescript + * class LoggingPlugin implements MultiAgentPlugin { + * get name(): string { + * return 'logging-plugin' + * } + * + * initMultiAgent(orchestrator: MultiAgent): void { + * orchestrator.addHook(BeforeNodeCallEvent, (event) => { + * console.log(`Node ${event.nodeId} starting`) + * }) + * } + * } + * + * const swarm = new Swarm({ + * nodes: [agentA, agentB], + * start: 'agentA', + * plugins: [new LoggingPlugin()], + * }) + * ``` + */ +export interface MultiAgentPlugin { + /** + * A stable string identifier for the plugin. + * Used for logging, duplicate detection, and plugin management. + */ + readonly name: string + + /** + * Initialize the plugin with the orchestrator instance. + * + * Implement this method to register hooks and perform custom initialization. + * + * @param orchestrator - The orchestrator this plugin is being attached to + */ + initMultiAgent(orchestrator: MultiAgent): void | Promise +} + +/** + * Registry for managing plugins attached to a multi-agent orchestrator. + * + * Holds pending plugins and initializes them on first use. + * Handles duplicate detection and calls each plugin's initMultiAgent method. + */ +export class MultiAgentPluginRegistry { + private readonly _plugins: Map + private readonly _pending: MultiAgentPlugin[] + + constructor(plugins: MultiAgentPlugin[] = []) { + this._plugins = new Map() + this._pending = [...plugins] + } + + /** + * Initialize all pending plugins with the orchestrator. + * Safe to call multiple times — only runs once. + * + * @param orchestrator - The orchestrator instance to initialize plugins with + */ + async initialize(orchestrator: MultiAgent): Promise { + while (this._pending.length > 0) { + const plugin = this._pending.shift()! + await this._addAndInit(plugin, orchestrator) + } + } + + private async _addAndInit(plugin: MultiAgentPlugin, orchestrator: MultiAgent): Promise { + if (this._plugins.has(plugin.name)) { + throw new Error(`plugin_name=<${plugin.name}> | plugin already registered`) + } + this._plugins.set(plugin.name, plugin) + await plugin.initMultiAgent(orchestrator) + } +} diff --git a/strands-ts/src/multiagent/queue.ts b/strands-ts/src/multiagent/queue.ts new file mode 100644 index 0000000000..c51e0a6618 --- /dev/null +++ b/strands-ts/src/multiagent/queue.ts @@ -0,0 +1,98 @@ +import type { Node } from './nodes.js' +import type { MultiAgentStreamEvent } from './events.js' +import type { NodeResult } from './state.js' + +/** + * Data produced by a running node: a streaming event, a completion signal, or an error. + */ +export type QueueData = + | { type: 'event'; node: Node; event: MultiAgentStreamEvent } + | { type: 'result'; node: Node; result: NodeResult } + | { type: 'error'; node: Node; error: Error } + +/** + * Queue data paired with an acknowledgement callback. + * The consumer must call {@link ack} after fully processing the data + * to unblock any producer waiting via {@link Queue.send}. + */ +export interface QueueEntry { + data: QueueData + ack: () => void +} + +/** + * Async queue with promise-based notification and optional back-pressure. + * + * Producers use {@link push} for fire-and-forget or {@link send} to + * block until the consumer has fully processed the data. The consumer calls + * {@link shift} to dequeue, then {@link QueueEntry.ack} after + * processing to unblock the producer. + */ +export class Queue { + private readonly _entries: QueueEntry[] = [] + /** Resolve function for the pending wait() promise, if any. */ + private _notify?: (() => void) | undefined + private _disposed = false + + /** + * Push data to the queue, waking any waiting consumer. + */ + push(data: QueueData): void { + this._entries.push({ data, ack: () => {} }) + this._notify?.() + this._notify = undefined + } + + /** + * Push data and wait until the consumer has fully processed it. + * Provides back-pressure so the producer pauses until the event + * has been yielded and hook callbacks have been invoked. + * + * @param data - The queue data to push + * @returns Promise that resolves when the consumer calls {@link QueueEntry.ack} + */ + send(data: QueueData): Promise { + if (this._disposed) return Promise.resolve() + + return new Promise((resolve) => { + this._entries.push({ data, ack: resolve }) + this._notify?.() + this._notify = undefined + }) + } + + /** + * Wait until at least one entry is available. + */ + wait(): Promise { + if (this._entries.length > 0) return Promise.resolve() + return new Promise((resolve) => { + this._notify = resolve + }) + } + + /** + * Remove and return the next entry, or undefined if empty. + */ + shift(): QueueEntry | undefined { + return this._entries.shift() + } + + /** + * Dispose the queue by resolving all pending acks and draining entries. + * Future {@link send} calls resolve immediately. + */ + dispose(): void { + this._disposed = true + while (this._entries.length > 0) { + this._entries.shift()!.ack() + } + } + + /** + * Number of entries in the queue. + */ + get size(): number { + return this._entries.length + } +} diff --git a/strands-ts/src/multiagent/snapshot.ts b/strands-ts/src/multiagent/snapshot.ts new file mode 100644 index 0000000000..bc927df454 --- /dev/null +++ b/strands-ts/src/multiagent/snapshot.ts @@ -0,0 +1,96 @@ +/** + * Snapshot implementation for multi-agent orchestrators (Graph and Swarm). + * + * Well-known keys in data: + * - `orchestratorId` — orchestrator identity for validation on load + * - `state` — serialized MultiAgentState (absent for nested orchestrators + * whose execution state is ephemeral) + */ + +import type { JSONValue } from '../types/json.js' +import { createTimestamp } from '../agent/snapshot.js' +import { SNAPSHOT_SCHEMA_VERSION } from '../types/snapshot.js' +import type { Snapshot } from '../types/snapshot.js' +import type { MultiAgentState } from './state.js' +import { serializeStateSerializable, loadStateSerializable } from '../types/serializable.js' +import type { Swarm } from './swarm.js' +import type { Graph } from './graph.js' + +/** + * Options for taking a multi-agent snapshot. + */ +export interface TakeMultiAgentSnapshotOptions { + /** Application-owned data. Strands does not read or modify this. */ + appData?: Record +} + +/** + * Takes a snapshot of a multi-agent orchestrator's current state. + * + * NOTE: This is currently an internal implementation detail. We anticipate + * exposing this as a public method in a future release after API review. + * + * @param orchestrator - The Graph or Swarm to snapshot + * @param state - The current execution state, or undefined for nested orchestrators + * whose state is ephemeral and not available from outside + * @param options - Multi-agent snapshot options + * @returns A snapshot of the orchestrator's state + */ +export function takeSnapshot( + orchestrator: Graph | Swarm, + state?: MultiAgentState, + options: TakeMultiAgentSnapshotOptions = {} +): Snapshot { + const data: Record = { + orchestratorId: orchestrator.id, + } + + if (state) { + data.state = serializeStateSerializable(state) + } + + return { + scope: 'multiAgent', + schemaVersion: SNAPSHOT_SCHEMA_VERSION, + createdAt: createTimestamp(), + data, + appData: options.appData ?? {}, + } +} + +/** + * Loads a multi-agent snapshot, restoring execution state. + * + * Follows the same mutate-in-place pattern as the agent snapshot: if a `state` + * instance is provided, execution state is loaded into it. Execution state is a + * separate parameter (rather than a field on the orchestrator) because orchestrators + * create ephemeral state per `stream()` call — there is no persistent state field + * to mutate. + * + * NOTE: This is currently an internal implementation detail. We anticipate + * exposing this as a public method in a future release after API review. + * + * @param orchestrator - The Graph or Swarm to restore into + * @param snapshot - The snapshot to load + * @param state - Optional MultiAgentState to restore execution state into + */ +export function loadSnapshot(orchestrator: Graph | Swarm, snapshot: Snapshot, state?: MultiAgentState): void { + if (snapshot.scope !== 'multiAgent') { + throw new Error(`Expected snapshot scope 'multiAgent', got '${snapshot.scope}'`) + } + if (snapshot.schemaVersion !== SNAPSHOT_SCHEMA_VERSION) { + throw new Error( + `Unsupported snapshot schema version: ${snapshot.schemaVersion}. Current version: ${SNAPSHOT_SCHEMA_VERSION}` + ) + } + + if (snapshot.data.orchestratorId !== orchestrator.id) { + throw new Error( + `Snapshot orchestrator ID mismatch: expected '${orchestrator.id}', got '${snapshot.data.orchestratorId}'` + ) + } + + if (state && 'state' in snapshot.data) { + loadStateSerializable(state, snapshot.data.state) + } +} diff --git a/strands-ts/src/multiagent/state.ts b/strands-ts/src/multiagent/state.ts new file mode 100644 index 0000000000..5ed5ba2f28 --- /dev/null +++ b/strands-ts/src/multiagent/state.ts @@ -0,0 +1,393 @@ +import { StateStore } from '../state-store.js' +import { type ContentBlock, contentBlockFromData } from '../types/messages.js' +import type { Usage } from '../models/streaming.js' +import { accumulateUsage, createEmptyUsage } from '../models/streaming.js' +import type { z } from 'zod' +import type { JSONValue } from '../types/json.js' +import { normalizeError, serializeError } from '../errors.js' +import { Interrupt } from '../interrupt.js' +import type { MultiAgentInput } from './multiagent.js' +import type { Snapshot } from '../types/snapshot.js' +import { + loadStateFromJSONSymbol, + stateToJSONSymbol, + serializeStateSerializable, + loadStateSerializable, + type StateSerializable, +} from '../types/serializable.js' + +/** + * Execution lifecycle status shared across all multi-agent patterns. + */ +export const Status = { + /** Execution has not yet started. */ + PENDING: 'PENDING', + /** Execution is currently in progress. */ + EXECUTING: 'EXECUTING', + /** Execution finished successfully. */ + COMPLETED: 'COMPLETED', + /** Execution encountered an error. */ + FAILED: 'FAILED', + /** Execution was cancelled before or during processing. */ + CANCELLED: 'CANCELLED', + /** Execution paused awaiting an interrupt response; can be resumed. */ + INTERRUPTED: 'INTERRUPTED', +} as const + +/** + * Union of all valid status values. + */ +export type Status = (typeof Status)[keyof typeof Status] + +/** + * Subset of {@link Status} valid for a {@link NodeResult}. + */ +export type ResultStatus = + | typeof Status.COMPLETED + | typeof Status.FAILED + | typeof Status.CANCELLED + | typeof Status.INTERRUPTED + +/** + * Result of executing a single node. + */ +export class NodeResult { + readonly type = 'nodeResult' as const + readonly nodeId: string + readonly status: ResultStatus + /** Execution time in milliseconds. */ + readonly duration: number + readonly content: ContentBlock[] + readonly error?: Error + /** Validated structured output, if a schema was provided. */ + readonly structuredOutput?: z.output + /** Token usage from the node execution. */ + readonly usage?: Usage + /** Interrupts raised by the underlying agent/orchestrator. Present iff `status === 'INTERRUPTED'`. */ + readonly interrupts?: Interrupt[] + + constructor(data: { + nodeId: string + status: ResultStatus + duration: number + content?: ContentBlock[] + error?: Error + structuredOutput?: z.output + usage?: Usage + interrupts?: Interrupt[] + }) { + this.nodeId = data.nodeId + this.status = data.status + this.duration = data.duration + this.content = data.content ?? [] + if ('error' in data) this.error = data.error + if ('structuredOutput' in data) this.structuredOutput = data.structuredOutput + if ('usage' in data) this.usage = data.usage + if (data.interrupts && data.interrupts.length > 0) this.interrupts = data.interrupts + } + + /** Serializes this result to a JSON-compatible value. */ + toJSON(): JSONValue { + return { + type: this.type, + nodeId: this.nodeId, + status: this.status, + duration: this.duration, + content: this.content.map((block) => block.toJSON()), + ...(this.error && { error: serializeError(this.error) }), + ...(this.structuredOutput !== undefined && { structuredOutput: this.structuredOutput as JSONValue }), + ...(this.usage && { usage: { ...this.usage } }), + ...(this.interrupts && { interrupts: this.interrupts.map((i) => i.toJSON()) }), + } as JSONValue + } + + /** Creates a NodeResult from a previously serialized JSON value. */ + static fromJSON(data: JSONValue): NodeResult { + const json = data as Record + return new NodeResult({ + nodeId: json.nodeId as string, + status: json.status as ResultStatus, + duration: json.duration as number, + content: (json.content as JSONValue[]).map((c) => contentBlockFromData(c as never)), + ...(json.error && { error: normalizeError(json.error) }), + ...(json.structuredOutput !== undefined && { structuredOutput: json.structuredOutput }), + ...(json.usage && { usage: json.usage as unknown as Usage }), + ...(json.interrupts && { + interrupts: (json.interrupts as JSONValue[]).map((i) => Interrupt.fromJSON(i as never)), + }), + }) + } +} + +/** + * Partial result returned by {@link Node.handle} implementations. + * + * Contains implementer-controlled fields that are merged with + * framework-managed defaults (nodeId, status, duration, content) to + * produce the final {@link NodeResult}. + */ +export type NodeResultUpdate = Partial> + +/** + * Execution state of a single node within a multi-agent orchestration. + */ +export class NodeState implements StateSerializable { + readonly type = 'nodeState' as const + status: Status + /** Whether this node is a terminal node — one where an execution path ended. */ + terminus: boolean + /** Node execution start time in milliseconds since epoch. */ + startTime: number + readonly results: NodeResult[] + /** Unanswered interrupts raised during this node's most recent run. Populated when `status === 'INTERRUPTED'`. */ + interrupts: Interrupt[] + /** + * Snapshot of the node's underlying runnable (Agent or nested orchestrator) captured + * when the node returned INTERRUPTED. Loaded back into the runnable on resume so it + * can pick up mid-execution without losing its interrupt bookkeeping. Cleared when + * the node completes. + */ + interruptedSnapshot?: Snapshot + + constructor() { + this.status = Status.PENDING + this.terminus = false + this.startTime = Date.now() + this.results = [] + this.interrupts = [] + } + + /** Content from the most recent result, or empty array if none. */ + get content(): readonly ContentBlock[] { + const last = this.results[this.results.length - 1] + return last?.content ?? [] + } + + /** Returns the serialized state as a JSON value. */ + [stateToJSONSymbol](): JSONValue { + return { + status: this.status, + terminus: this.terminus, + startTime: this.startTime, + results: this.results.map((res) => res.toJSON()), + interrupts: this.interrupts.map((i) => i.toJSON()), + ...(this.interruptedSnapshot && { interruptedSnapshot: { ...this.interruptedSnapshot } }), + } as JSONValue + } + + /** Loads state from a previously serialized JSON value. */ + [loadStateFromJSONSymbol](json: JSONValue): void { + const data = json as Record + this.status = data.status as Status + this.terminus = data.terminus as boolean + this.startTime = data.startTime as number + this.results.length = 0 + for (const entry of data.results as JSONValue[]) { + this.results.push(NodeResult.fromJSON(entry)) + } + this.interrupts = ((data.interrupts as JSONValue[] | undefined) ?? []).map((i) => Interrupt.fromJSON(i as never)) + if (data.interruptedSnapshot) { + this.interruptedSnapshot = data.interruptedSnapshot as unknown as Snapshot + } else { + delete this.interruptedSnapshot + } + } +} + +/** + * Aggregate result from a multi-agent execution. + */ +export class MultiAgentResult { + readonly type = 'multiAgentResult' as const + readonly status: ResultStatus + readonly results: NodeResult[] + /** Combined content from terminus nodes, in completion order. */ + readonly content: ContentBlock[] + readonly duration: number + readonly error?: Error + /** Aggregated token usage across all node results. */ + readonly usage: Usage + /** Interrupts aggregated across all node results. Present when any node ended INTERRUPTED. */ + readonly interrupts?: Interrupt[] + + constructor(data: { + status?: ResultStatus + results: NodeResult[] + content?: ContentBlock[] + duration: number + error?: Error + interrupts?: Interrupt[] + }) { + this.status = data.status ?? this._resolveStatus(data.results) + this.results = data.results + this.content = data.content ?? [] + this.duration = data.duration + if ('error' in data) this.error = data.error + this.usage = this._aggregateNodeUsage(data.results) + const interrupts = data.interrupts ?? data.results.flatMap((r) => r.interrupts ?? []) + if (interrupts.length > 0) this.interrupts = interrupts + } + + /** Serializes this result to a JSON-compatible value. */ + toJSON(): JSONValue { + return { + type: this.type, + status: this.status, + results: this.results.map((result) => result.toJSON()), + content: this.content.map((block) => block.toJSON()), + duration: this.duration, + usage: { ...this.usage }, + ...(this.error && { error: serializeError(this.error) }), + ...(this.interrupts && { interrupts: this.interrupts.map((i) => i.toJSON()) }), + } as JSONValue + } + + /** Creates a MultiAgentResult from a previously serialized JSON value. */ + static fromJSON(data: JSONValue): MultiAgentResult { + const json = data as Record + return new MultiAgentResult({ + status: json.status as ResultStatus, + results: (json.results as JSONValue[]).map(NodeResult.fromJSON), + content: (json.content as JSONValue[]).map((c) => contentBlockFromData(c as never)), + duration: json.duration as number, + ...(json.error && { error: normalizeError(json.error) }), + ...(json.interrupts && { + interrupts: (json.interrupts as JSONValue[]).map((i) => Interrupt.fromJSON(i as never)), + }), + }) + } + + /** + * Derives the aggregate status from individual node results. + * + * Precedence: FAILED \> INTERRUPTED \> CANCELLED \> COMPLETED. INTERRUPTED outranks + * CANCELLED because parallel-graph short-circuit aborts siblings as CANCELLED when + * one node interrupts — the actionable "resume me" signal should surface over the + * collateral cancellations. + */ + private _resolveStatus(results: NodeResult[]): ResultStatus { + if (results.some((result) => result.status === Status.FAILED)) return Status.FAILED + if (results.some((result) => result.status === Status.INTERRUPTED)) return Status.INTERRUPTED + if (results.some((result) => result.status === Status.CANCELLED)) return Status.CANCELLED + return Status.COMPLETED + } + + /** Sums token usage across all node results. */ + private _aggregateNodeUsage(results: NodeResult[]): Usage { + const usage = createEmptyUsage() + for (const result of results) { + if (!result.usage) continue + accumulateUsage(usage, result.usage) + } + return usage + } +} + +/** + * Rehydrates a serialized `_pendingInput` back to its runtime shape. `string` round-trips + * as-is; array inputs (which serialize as `ContentBlockData[]` via each block's `toJSON`) + * are mapped through `contentBlockFromData` so downstream callers see `ContentBlock[]` + * instead of raw data objects. + */ +function rehydratePendingInput(value: JSONValue): MultiAgentInput { + if (typeof value === 'string') return value + if (Array.isArray(value)) { + return (value as JSONValue[]).map((entry) => contentBlockFromData(entry as never)) as ContentBlock[] + } + // Unexpected shape — pass through so callers see the exact value and can diagnose. + return value as unknown as MultiAgentInput +} + +/** + * Per-execution state for multi-agent orchestration, created fresh each invocation. + */ +export class MultiAgentState implements StateSerializable { + /** Execution start time in milliseconds since epoch. */ + readonly startTime: number + /** Number of node executions started so far. */ + steps: number + /** All node results in completion order. */ + readonly results: NodeResult[] + /** App-level key-value state accessible from hooks, edge handlers, and custom nodes. */ + readonly app: StateStore + /** + * The invocation's input, carried through an interrupt pause so that resuming a + * run (on the same instance, or via a SessionManager) can re-enter nodes that + * never ran (hook-gated source/start nodes) with the original content. Cleared + * when the invocation terminates in any non-INTERRUPTED state. + * + * @internal — not part of the public state shape; orchestrator-owned. + */ + _pendingInput?: MultiAgentInput + private readonly _nodes: Map + + constructor(data?: { nodeIds?: string[] }) { + this.startTime = Date.now() + this.steps = 0 + this.results = [] + this.app = new StateStore() + this._nodes = new Map() + for (const id of data?.nodeIds ?? []) { + this._nodes.set(id, new NodeState()) + } + } + + /** + * Get the state of a specific node by ID. + * + * @param id - The node identifier + * @returns The node's state, or undefined if the node is not tracked + */ + node(id: string): NodeState | undefined { + return this._nodes.get(id) + } + + /** + * All tracked node states. + */ + get nodes(): ReadonlyMap { + return this._nodes + } + + /** Returns the serialized state as a JSON value. */ + [stateToJSONSymbol](): JSONValue { + const nodes: Record = {} + for (const [id, nodeState] of this._nodes) { + nodes[id] = serializeStateSerializable(nodeState) + } + return { + startTime: this.startTime, + steps: this.steps, + results: this.results.map((result) => result.toJSON()), + app: serializeStateSerializable(this.app), + nodes, + ...(this._pendingInput !== undefined && { _pendingInput: this._pendingInput as unknown as JSONValue }), + } as JSONValue + } + + /** Loads state from a previously serialized JSON value. */ + [loadStateFromJSONSymbol](json: JSONValue): void { + const data = json as Record + ;(this as { startTime: number }).startTime = data.startTime as number + this.steps = data.steps as number + this.results.length = 0 + for (const entry of data.results as JSONValue[]) { + this.results.push(NodeResult.fromJSON(entry)) + } + loadStateSerializable(this.app, data.app as JSONValue) + this._nodes.clear() + const nodes = data.nodes as Record | undefined + if (nodes) { + for (const [id, nodeData] of Object.entries(nodes)) { + const nodeState = new NodeState() + loadStateSerializable(nodeState, nodeData) + this._nodes.set(id, nodeState) + } + } + if (data._pendingInput !== undefined) { + this._pendingInput = rehydratePendingInput(data._pendingInput) + } else { + delete this._pendingInput + } + } +} diff --git a/strands-ts/src/multiagent/swarm.ts b/strands-ts/src/multiagent/swarm.ts new file mode 100644 index 0000000000..f025272ed6 --- /dev/null +++ b/strands-ts/src/multiagent/swarm.ts @@ -0,0 +1,650 @@ +import { logger } from '../logging/logger.js' +import { warnOnce } from '../logging/warn-once.js' +import type { AttributeValue, Span } from '@opentelemetry/api' +import type { InvocationState, InvokableAgent } from '../types/agent.js' +import type { MultiAgentInput, MultiAgentInvokeOptions } from './multiagent.js' +import { + applyOrchestratorHookResponses, + dropStaleInterruptedResult, + extractResumeResponses, + groupInterruptResponsesByNode, + recordHookInterrupt, +} from './multiagent.js' +import { InterruptError } from '../interrupt.js' +import { z } from 'zod' +import { HookableEvent } from '../hooks/events.js' +import { HookRegistryImplementation } from '../hooks/registry.js' +import type { HookCallback, HookableEventConstructor, HookCleanup } from '../hooks/types.js' +import type { MultiAgentPlugin } from './plugins.js' +import { MultiAgentPluginRegistry } from './plugins.js' +import type { SessionManager } from '../session/session-manager.js' +import type { ContentBlock } from '../types/messages.js' +import { TextBlock } from '../types/messages.js' +import type { AgentNodeOptions } from './nodes.js' +import { AgentNode } from './nodes.js' +import { MultiAgentState, MultiAgentResult, NodeResult, Status } from './state.js' +import type { MultiAgent } from './multiagent.js' +import type { MultiAgentStreamEvent } from './events.js' +import { + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentHandoffEvent, + MultiAgentInitializedEvent, + MultiAgentResultEvent, + NodeCancelEvent, + NodeResultEvent, +} from './events.js' +import { Tracer } from '../telemetry/tracer.js' +import { normalizeError } from '../errors.js' + +/** + * Runtime configuration for swarm execution. + */ +export interface SwarmConfig { + /** Max total agent executions (including start). Defaults to `Infinity` (no limit). */ + maxSteps?: number + /** + * Wall-clock ceiling for the entire swarm invocation, in milliseconds. Defaults to `Infinity` + * (no limit). Composed with each node's cancel signal, so a node that exceeds this bound + * mid-execution will be aborted (cooperatively). + */ + timeout?: number + /** + * Fallback per-node wall-clock ceiling in milliseconds. Applied to any node that doesn't + * set its own `timeout`. Defaults to `Infinity` (no limit). + * + * Enforced via `AbortSignal` — cancellation is cooperative, so a tool that neither polls + * its cancel signal nor forwards it to a cancellable API can run past this deadline. + */ + nodeTimeout?: number +} + +/** + * Structured output each agent produces to decide the next step. + * + * When `agentId` is provided, the swarm hands off to that agent with + * `message` as input. When omitted, `message` becomes the final response. + */ +interface HandoffResult { + /** Agent id to hand off to. Omit to end the swarm and return `message` as the final response. */ + agentId?: string + /** Instructions for the next agent, or the final response if no handoff. */ + message: string + /** Structured data to pass to the next agent. Serialized as a JSON text block alongside the handoff message. */ + context?: Record +} + +/** + * Input type for swarm nodes. Pass an {@link InvokableAgent} directly for the simple case, + * or {@link AgentNodeOptions} for per-node config. + */ +export type SwarmNodeDefinition = InvokableAgent | AgentNodeOptions + +export interface SwarmOptions extends SwarmConfig { + /** Unique identifier. Defaults to `'swarm'`. */ + id?: string + /** Swarm agents. Pass agents directly or use {@link AgentNodeOptions} for per-node config. */ + nodes: SwarmNodeDefinition[] + /** Agent id that receives the initial input. Defaults to the first agent in `nodes`. */ + start?: string + /** Session manager for saving and restoring swarm sessions. */ + sessionManager?: SessionManager + /** Plugins for event-driven extensibility. */ + plugins?: MultiAgentPlugin[] + /** Custom trace attributes to include on all spans. */ + traceAttributes?: Record +} + +/** + * Swarm multi-agent orchestration pattern. + * + * Agents execute sequentially, each deciding whether to hand off to another agent or + * produce a final response. Routing is driven by structured output: each agent receives + * a Zod schema with `agentId`, `message`, and optional `context` fields. When `agentId` + * is present, the swarm hands off to that agent with `message` as input. When omitted, + * `message` becomes the final response. + * + * Key design choices vs the Python SDK: + * - Handoffs use structured output rather than an injected `handoff_to_agent` tool. + * Routing logic stays in the orchestrator, not inside tool callbacks. + * - Context is passed as serialized JSON text blocks rather than a mutable SharedContext. + * - A single `maxSteps` limit replaces Python's separate `max_handoffs`/`max_iterations`. + * - Agent descriptions are embedded in the structured output schema for routing decisions. + * - Exceeding `maxSteps` throws an exception. Python returns a FAILED result. + * + * @example + * ```typescript + * const swarm = new Swarm({ + * nodes: [researcher, writer], + * start: 'researcher', + * maxSteps: 10, + * }) + * + * const result = await swarm.invoke('Explain quantum computing') + * ``` + */ +export class Swarm implements MultiAgent { + readonly id: string + readonly nodes: ReadonlyMap + readonly config: Required + private readonly _pluginRegistry: MultiAgentPluginRegistry + private readonly _hookRegistry: HookRegistryImplementation + private readonly _tracer: Tracer + readonly start: AgentNode + readonly sessionManager?: SessionManager | undefined + private _initialized: boolean + /** + * State retained across invocations when a run ends INTERRUPTED. Lets + * `swarm.invoke(responses)` resume on the same instance without requiring a + * SessionManager, mirroring single-agent ergonomics. Cleared when a run + * terminates in any non-INTERRUPTED state. + */ + private _pendingInterruptState?: MultiAgentState + + constructor(options: SwarmOptions) { + const { id, nodes, start, sessionManager, plugins, traceAttributes, ...config } = options + + this.id = id ?? 'swarm' + + this.config = { + maxSteps: config.maxSteps ?? Infinity, + timeout: config.timeout ?? Infinity, + nodeTimeout: config.nodeTimeout ?? Infinity, + } + this._validateConfig() + + if (this.config.maxSteps === Infinity && this.config.timeout === Infinity) { + warnOnce(logger, 'swarm has no maxSteps or timeout set; execution is unbounded') + } + + this.nodes = this._resolveNodes(nodes) + this.start = this._resolveStart(start) + + this.sessionManager = sessionManager + + if (sessionManager && plugins?.some((p) => p.name === sessionManager.name)) { + throw new Error('sessionManager was provided as both a constructor argument and in the plugins array') + } + + this._hookRegistry = new HookRegistryImplementation() + this._pluginRegistry = new MultiAgentPluginRegistry([ + ...(plugins ?? []), + ...(sessionManager ? [sessionManager] : []), + ]) + this._tracer = new Tracer(traceAttributes) + this._initialized = false + } + + /** + * Initialize the swarm. Invokes the {@link MultiAgentInitializedEvent} callback. + * Called automatically on first invocation. + */ + async initialize(): Promise { + if (this._initialized) return + await this._pluginRegistry.initialize(this) + await this._hookRegistry.invokeCallbacks(new MultiAgentInitializedEvent({ orchestrator: this })) + this._initialized = true + } + + /** + * Register a hook callback for a specific swarm event type. + * + * @param eventType - The event class constructor to register the callback for + * @param callback - The callback function to invoke when the event occurs + * @returns Cleanup function that removes the callback when invoked + */ + addHook(eventType: HookableEventConstructor, callback: HookCallback): HookCleanup { + return this._hookRegistry.addCallback(eventType, callback) + } + + /** + * Invoke swarm and return final result (consumes stream). + * + * @param input - The input to pass to the start agent + * @param options - Optional per-invocation options (e.g., {@link InvocationState}) + * @returns Promise resolving to the final MultiAgentResult + */ + async invoke(input: MultiAgentInput, options?: MultiAgentInvokeOptions): Promise { + const gen = this.stream(input, options) + let next = await gen.next() + while (!next.done) { + next = await gen.next() + } + return next.value + } + + /** + * Stream swarm execution, yielding events as agents execute. + * Invokes hook callbacks for each event before yielding. + * + * @param input - The input to pass to the start agent + * @param options - Optional per-invocation options (e.g., {@link InvocationState}) + * @returns Async generator yielding streaming events and returning a MultiAgentResult + */ + async *stream( + input: MultiAgentInput, + options?: MultiAgentInvokeOptions + ): AsyncGenerator { + await this.initialize() + + // Shared by reference across every node so mutations in one node's agent + // are visible to the next. + const invocationState: InvocationState = options?.invocationState ?? {} + + // Hook invocation lives in `_stream` so hook-raised `InterruptError`s land in the + // same frame as the execution loop. + const gen = this._stream(input, invocationState, options?.cancelSignal) + let next = await gen.next() + while (!next.done) { + yield next.value + next = await gen.next() + } + return next.value + } + + private async *_stream( + input: MultiAgentInput, + invocationState: InvocationState, + externalCancelSignal?: AbortSignal + ): AsyncGenerator { + // Reuse state from a prior INTERRUPTED run so `swarm.invoke(responses)` can + // resume on the same instance without a SessionManager. + const state = + this._pendingInterruptState ?? + new MultiAgentState({ + nodeIds: [...this.nodes.keys()], + }) + delete this._pendingInterruptState + + const multiAgentSpan = this._tracer.startMultiAgentSpan({ + orchestratorId: this.id, + orchestratorType: 'swarm', + input, + }) + + // SessionManager (or plugins) may restore state.results here via the hook + yield* this._emit(new BeforeMultiAgentInvocationEvent({ orchestrator: this, state, invocationState })) + + // Resume input bypasses handoff-derived resume (goes straight to the interrupted + // node). On fresh runs, stash the input for replay if a hook-gate pauses before + // the node runs. + const resumeResponses = extractResumeResponses(input) + const interruptResponsesByNode = resumeResponses ? groupInterruptResponsesByNode(resumeResponses, state) : undefined + if (!resumeResponses) { + state._pendingInput = input + } + + let node: AgentNode + let handoff: HandoffResult | undefined + let nextInput: MultiAgentInput = input + if (interruptResponsesByNode) { + // Swarm runs sequentially, so at most one node can be INTERRUPTED per run. + // Assert the invariant so a future change that accidentally produces multiple + // interrupted nodes surfaces loudly rather than silently taking the first. + if (interruptResponsesByNode.size > 1) { + throw new Error( + `swarm_id=<${this.id}>, interrupted_nodes=<${[...interruptResponsesByNode.keys()].join(',')}> | swarm cannot have multiple interrupted nodes simultaneously` + ) + } + const entry = interruptResponsesByNode.entries().next().value + if (!entry) throw new Error(`swarm_id=<${this.id}> | no interrupt responses to route`) + const [nodeId, responses] = entry + const resolvedNode = this.nodes.get(nodeId) + if (!resolvedNode) { + throw new Error( + `node_id=<${nodeId}>, swarm_id=<${this.id}> | resume response targets a node missing from the swarm; topology changed between save and resume?` + ) + } + node = resolvedNode + const resolvedNodeState = state.node(nodeId) + if (!resolvedNodeState) { + throw new Error( + `node_id=<${nodeId}>, swarm_id=<${this.id}> | routed interrupt response targets a node missing from state; topology changed between save and resume?` + ) + } + + // Orchestrator hooks consume matching responses; leftovers go to the child + // agent. If the hook consumed everything, replay the original invocation input. + const forwarded = applyOrchestratorHookResponses(resolvedNodeState, responses) + nextInput = forwarded.length > 0 ? forwarded : (state._pendingInput ?? '') + } else { + const resumeNode = this._findResumeNode(state) + node = resumeNode?.node ?? this.start + handoff = resumeNode?.lastHandoff + } + + let caughtError: Error | undefined + let result: MultiAgentResult | undefined + + // Swarm-level timeout composes with each node's signal so a hung node still gets + // aborted. Timer starts fresh per invocation; human response time between resumes + // is not deducted. + const execController = Number.isFinite(this.config.timeout) ? new AbortController() : undefined + const execTimeoutHandle = execController ? setTimeout(() => execController.abort(), this.config.timeout) : undefined + + const nodeCancelSignal = + execController && externalCancelSignal + ? AbortSignal.any([execController.signal, externalCancelSignal]) + : (execController?.signal ?? externalCancelSignal) + + try { + while (state.steps < this.config.maxSteps) { + if (execController?.signal.aborted) { + throw new Error(`timeout=<${this.config.timeout}>, swarm_id=<${this.id}> | swarm exceeded wall-clock budget`) + } + if (externalCancelSignal?.aborted) { + throw new Error(`swarm_id=<${this.id}> | swarm cancelled by external signal`) + } + state.steps++ + + // After the first step (which may use routed resume responses), revert to the + // original input so post-handoff nodes see fresh content. + const nodeResult = yield* this._streamNode( + node, + nextInput, + state, + handoff, + multiAgentSpan, + invocationState, + nodeCancelSignal + ) + nextInput = input + handoff = nodeResult.structuredOutput as HandoffResult | undefined + + if (execController?.signal.aborted) { + throw new Error( + `timeout=<${this.config.timeout}>, swarm_id=<${this.id}>, node_id=<${node.id}> | swarm exceeded wall-clock budget during node execution` + ) + } + + // Check for terminal conditions + if (nodeResult.status === Status.FAILED || nodeResult.status === Status.INTERRUPTED || !handoff?.agentId) { + break + } + + // Hand off to next agent + const target = this.nodes.get(handoff.agentId)! + yield* this._emit(new MultiAgentHandoffEvent({ source: node.id, targets: [target.id], state, invocationState })) + logger.debug(`source=<${node.id}>, target=<${target.id}> | swarm handoff`) + node = target + } + + this._checkSteps(state, handoff) + + result = new MultiAgentResult({ + results: state.results, + content: this._resolveContent(state), + duration: Date.now() - state.startTime, + }) + // Stash on interrupt so same-instance resume has state; otherwise start fresh. + if (result.status === Status.INTERRUPTED) { + this._pendingInterruptState = state + } else { + delete this._pendingInterruptState + delete state._pendingInput + } + } catch (error) { + caughtError = normalizeError(error) + throw caughtError + } finally { + if (execTimeoutHandle !== undefined) clearTimeout(execTimeoutHandle) + this._tracer.endMultiAgentSpan(multiAgentSpan, { + duration: Date.now() - state.startTime, + ...(result && { usage: result.usage }), + ...(caughtError && { error: caughtError }), + }) + + yield* this._emit(new AfterMultiAgentInvocationEvent({ orchestrator: this, state, invocationState })) + } + + yield* this._emit(new MultiAgentResultEvent({ result, invocationState })) + return result + } + + /** Invokes hook callbacks on an event, then yields it. */ + private async *_emit(event: T): AsyncGenerator { + await this._hookRegistry.invokeCallbacks(event) + yield event + } + + private async *_streamNode( + node: AgentNode, + input: MultiAgentInput, + state: MultiAgentState, + handoff: HandoffResult | undefined, + multiAgentSpan: Span | null, + invocationState: InvocationState, + executionSignal?: AbortSignal + ): AsyncGenerator { + const nodeState = state.node(node.id)! + const handoffSchema = this._buildHandoffSchema(node.id) + const nodeSpan = this._tracer.withSpanContext(multiAgentSpan, () => + this._tracer.startNodeSpan({ nodeId: node.id, nodeType: node.type }) + ) + + const beforeEvent = new BeforeNodeCallEvent({ orchestrator: this, state, nodeId: node.id, invocationState }) + try { + await this._hookRegistry.invokeCallbacks(beforeEvent) + } catch (error) { + if (error instanceof InterruptError) { + const result = recordHookInterrupt(node.id, nodeState) + state.results.push(result) + yield beforeEvent + yield* this._emit(new NodeResultEvent({ nodeId: node.id, nodeType: node.type, state, result, invocationState })) + yield* this._emit(new AfterNodeCallEvent({ orchestrator: this, state, nodeId: node.id, invocationState })) + this._tracer.endNodeSpan(nodeSpan, { status: Status.INTERRUPTED, duration: result.duration }) + return result + } + throw error + } + yield beforeEvent + + if (beforeEvent.cancel) { + const message = typeof beforeEvent.cancel === 'string' ? beforeEvent.cancel : 'node cancelled by hook' + // Cancel path doesn't go through Node.stream, so do its INTERRUPTED cleanup here. + dropStaleInterruptedResult(node.id, nodeState, state) + const result = new NodeResult({ nodeId: node.id, status: Status.CANCELLED, duration: 0 }) + nodeState.status = Status.CANCELLED + nodeState.results.push(result) + state.results.push(result) + yield* this._emit(new NodeCancelEvent({ nodeId: node.id, state, message, invocationState })) + yield* this._emit(new AfterNodeCallEvent({ orchestrator: this, state, nodeId: node.id, invocationState })) + this._tracer.endNodeSpan(nodeSpan, { status: Status.CANCELLED, duration: 0 }) + return result + } + + const nodeInput = this._resolveNodeInput(input, handoff) + + const nodeTimeout = node.timeout ?? this.config.nodeTimeout + const timeoutController = Number.isFinite(nodeTimeout) ? new AbortController() : undefined + const timeoutHandle = timeoutController ? setTimeout(() => timeoutController.abort(), nodeTimeout) : undefined + const signals = [executionSignal, timeoutController?.signal].filter((s): s is AbortSignal => s !== undefined) + const cancelSignal = signals.length > 0 ? AbortSignal.any(signals) : undefined + + try { + const gen = this._tracer.withSpanContext(nodeSpan, () => + node.stream(nodeInput, state, { + structuredOutputSchema: handoffSchema, + invocationState, + ...(cancelSignal && { cancelSignal }), + }) + ) + let next = await this._tracer.withSpanContext(nodeSpan, () => gen.next()) + while (!next.done) { + if (next.value instanceof HookableEvent) { + yield* this._emit(next.value) + } else { + yield next.value + } + next = await this._tracer.withSpanContext(nodeSpan, () => gen.next()) + } + + if (timeoutController?.signal.aborted) { + throw new Error( + `node_timeout=<${nodeTimeout}>, node_id=<${node.id}>, swarm_id=<${this.id}> | node exceeded wall-clock budget` + ) + } + + const result = next.value + this._tracer.endNodeSpan(nodeSpan, { status: result.status, duration: result.duration, usage: result.usage }) + state.results.push(result) + + yield* this._emit(new AfterNodeCallEvent({ orchestrator: this, state, nodeId: node.id, invocationState })) + return result + } catch (error) { + const nodeError = normalizeError(error) + this._tracer.endNodeSpan(nodeSpan, { error: nodeError }) + + yield* this._emit( + new AfterNodeCallEvent({ orchestrator: this, state, nodeId: node.id, invocationState, error: nodeError }) + ) + throw nodeError + } finally { + if (timeoutHandle !== undefined) clearTimeout(timeoutHandle) + } + } + + private _validateConfig(): void { + if (this.config.maxSteps < 1) { + throw new Error(`max_steps=<${this.config.maxSteps}> | must be at least 1`) + } + if (this.config.timeout < 1) { + throw new Error(`timeout=<${this.config.timeout}> | must be at least 1`) + } + if (this.config.nodeTimeout < 1) { + throw new Error(`node_timeout=<${this.config.nodeTimeout}> | must be at least 1`) + } + } + + private _resolveNodes(definitions: SwarmNodeDefinition[]): Map { + if (definitions.length === 0) { + throw new Error('nodes list is empty') + } + + const nodes = new Map() + for (const definition of definitions) { + const node = 'agent' in definition ? new AgentNode(definition) : new AgentNode({ agent: definition }) + if (nodes.has(node.id)) { + throw new Error(`agent_id=<${node.id}> | duplicate agent id`) + } + nodes.set(node.id, node) + } + return nodes + } + + private _resolveStart(start: string | undefined): AgentNode { + if (start === undefined) { + return this.nodes.values().next().value! + } + + const node = this.nodes.get(start) + if (!node) { + throw new Error(`start=<${start}> | start references unknown agent`) + } + return node + } + + private _resolveContent(state: MultiAgentState): ContentBlock[] { + const last = state.results[state.results.length - 1]! + state.node(last.nodeId)!.terminus = true + + const handoff = last.structuredOutput as HandoffResult | undefined + if (handoff?.message) { + return [new TextBlock(handoff.message)] + } + + return [...last.content] + } + + /** + * Builds the input for the next node after a handoff, or returns the input as-is + * when there is no handoff (initial or resume invocation). The caller passes the + * original `MultiAgentInput` through; resume responses flow through here untouched + * so the underlying agent sees them directly. + */ + private _resolveNodeInput(input: MultiAgentInput, handoff?: HandoffResult): MultiAgentInput { + if (!handoff) return input + + const blocks: ContentBlock[] = [new TextBlock(handoff.message)] + if (handoff.context) { + blocks.push(new TextBlock('Context:\n' + JSON.stringify(handoff.context, null, 2))) + } + return blocks + } + + /** + * Checks whether the swarm has exceeded its step limit with work still pending. + * + * This is only an error when the loop exhausted its step budget while the last agent + * still requested a handoff (i.e. there was more work to do). If the swarm completed + * normally on its final allowed step (no pending handoff), no error is thrown. + * + * @param state - Current swarm execution state + * @param handoff - The last handoff result from the most recent agent execution + * @throws Error when step limit is reached with a pending handoff + */ + private _checkSteps(state: MultiAgentState, handoff?: HandoffResult): void { + if (handoff?.agentId && state.steps >= this.config.maxSteps) { + throw new Error(`max_steps=<${this.config.maxSteps}> | swarm reached step limit`) + } + } + + /** + * Finds the next node to execute from a restored {@link MultiAgentState}. + * + * When the session manager restores state from a snapshot, `state.results` + * contains results from the previous invocation in completion order. The last + * result's structured output contains the handoff decision — if it has an + * `agentId`, that is the node the previous run intended to hand off to but + * never executed (e.g. due to a crash). We resume from that handoff target. + * + * If the last result has no `agentId`, the previous run completed normally + * and there is nothing to resume. + * + * @returns The handoff target node and its handoff context, or `undefined` for a fresh start + */ + private _findResumeNode(state: MultiAgentState): { node: AgentNode; lastHandoff: HandoffResult } | undefined { + const lastResult = state.results[state.results.length - 1] + if (!lastResult) return undefined + + const lastNodeHandoff = lastResult.structuredOutput as HandoffResult | undefined + if (!lastNodeHandoff?.agentId) return undefined + + const nextNode = this.nodes.get(lastNodeHandoff.agentId) + if (!nextNode) { + logger.warn(`node_id=<${lastNodeHandoff.agentId}> | resume target not found in swarm, starting fresh`) + return undefined + } + + logger.debug(`node_id=<${nextNode.id}>, prior_steps=<${state.steps}> | resuming swarm from restored state`) + return { node: nextNode, lastHandoff: lastNodeHandoff } + } + + private _buildHandoffSchema(nodeId: string): z.ZodType { + const handoffIds = [...this.nodes.keys()].filter((id) => id !== nodeId) + const handoffDescriptions = handoffIds + .map((id) => { + const desc = this.nodes.get(id)!.config.description + return desc ? `- ${id}: ${desc}` : `- ${id}` + }) + .join('\n') + + return z + .object({ + agentId: + handoffIds.length > 0 + ? z + .enum(handoffIds as [string, ...string[]]) + .optional() + .describe( + `Target agent to hand off to. Omit to end the conversation.\n\nAvailable agents:\n${handoffDescriptions}` + ) + : z.never().optional().describe('No other agents available. Omit this field to end the conversation.'), + message: z.string().describe('Instructions for the next agent, or the final response if no handoff.'), + context: z.record(z.string(), z.unknown()).optional().describe('Structured data to pass to the next agent.'), + }) + .describe('Decide whether to hand off to another agent or produce a final response.') as z.ZodType + } +} diff --git a/strands-ts/src/plugins/__tests__/plugin.test.ts b/strands-ts/src/plugins/__tests__/plugin.test.ts new file mode 100644 index 0000000000..98312907aa --- /dev/null +++ b/strands-ts/src/plugins/__tests__/plugin.test.ts @@ -0,0 +1,146 @@ +import { describe, it, expect } from 'vitest' +import type { Plugin } from '../plugin.js' +import { BeforeInvocationEvent, type HookableEvent } from '../../hooks/events.js' +import { ToolRegistry } from '../../registry/tool-registry.js' +import type { HookableEventConstructor, HookCallback, HookCleanup } from '../../hooks/types.js' +import type { LocalAgent } from '../../types/agent.js' +import { createRandomTool } from '../../__fixtures__/tool-helpers.js' + +/** + * Concrete implementation of Plugin for testing purposes. + */ +class TestPlugin implements Plugin { + callbacks: Array<{ eventType: unknown; callback: unknown }> = [] + + get name(): string { + return 'test-plugin' + } + + initAgent(agent: LocalAgent): void { + agent.addHook(BeforeInvocationEvent, () => { + // No-op for testing + }) + } +} + +/** + * Plugin with custom name for testing. + */ +class CustomNamePlugin implements Plugin { + private readonly _name: string + + constructor(name: string) { + this._name = name + } + + get name(): string { + return this._name + } + + initAgent(_agent: LocalAgent): void {} +} + +/** + * Plugin with initAgent implementation for testing. + */ +class InitializablePlugin implements Plugin { + public initialized = false + + get name(): string { + return 'initializable-plugin' + } + + initAgent(_agent: LocalAgent): void { + this.initialized = true + } +} + +describe('Plugin', () => { + describe('name', () => { + it('returns the plugin name', () => { + const plugin = new TestPlugin() + expect(plugin.name).toBe('test-plugin') + }) + + it('supports custom names via constructor', () => { + const plugin = new CustomNamePlugin('my-custom-plugin') + expect(plugin.name).toBe('my-custom-plugin') + }) + }) + + describe('initAgent', () => { + it('registers callbacks via agent.addHook', () => { + const plugin = new TestPlugin() + const callbacks: Array<{ + eventType: HookableEventConstructor + callback: HookCallback + }> = [] + const mockAgent = { + addHook: ( + eventType: HookableEventConstructor, + callback: HookCallback + ): HookCleanup => { + callbacks.push({ + eventType: eventType as HookableEventConstructor, + callback: callback as HookCallback, + }) + return () => {} + }, + toolRegistry: new ToolRegistry(), + } as unknown as LocalAgent + + plugin.initAgent(mockAgent) + + expect(callbacks).toHaveLength(1) + expect(callbacks[0]?.eventType).toBe(BeforeInvocationEvent) + }) + + it('has a no-op default when not overridden', () => { + const plugin: Plugin = new CustomNamePlugin('test') + const mockAgent = { + addHook: () => () => {}, + toolRegistry: new ToolRegistry(), + } as unknown as LocalAgent + + // Should not throw and return undefined + const result = plugin.initAgent(mockAgent) + expect(result).toBeUndefined() + }) + + it('can be implemented for custom initialization', () => { + const plugin = new InitializablePlugin() + const mockAgent = { + addHook: () => () => {}, + toolRegistry: new ToolRegistry(), + } as unknown as LocalAgent + + expect(plugin.initialized).toBe(false) + + plugin.initAgent(mockAgent) + + expect(plugin.initialized).toBe(true) + }) + }) + + describe('getTools', () => { + it('is optional — plugins without getTools are valid', () => { + const plugin: Plugin = new TestPlugin() + expect(plugin.getTools).toBeUndefined() + }) + + it('can be implemented to provide tools', () => { + const mockTool = createRandomTool() + class ToolPlugin implements Plugin { + get name(): string { + return 'tool-plugin' + } + initAgent(_agent: LocalAgent): void {} + getTools() { + return [mockTool] + } + } + + expect(new ToolPlugin().getTools()).toStrictEqual([mockTool]) + }) + }) +}) diff --git a/strands-ts/src/plugins/__tests__/registry.test.ts b/strands-ts/src/plugins/__tests__/registry.test.ts new file mode 100644 index 0000000000..c92a8b3531 --- /dev/null +++ b/strands-ts/src/plugins/__tests__/registry.test.ts @@ -0,0 +1,190 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest' +import { PluginRegistry } from '../registry.js' +import type { Plugin } from '../plugin.js' +import { BeforeInvocationEvent, type HookableEvent } from '../../hooks/events.js' +import type { Tool } from '../../tools/tool.js' +import type { HookableEventConstructor, HookCallback } from '../../hooks/types.js' +import type { LocalAgent } from '../../types/agent.js' +import { createMockAgent } from '../../__fixtures__/agent-helpers.js' +import { createRandomTool } from '../../__fixtures__/tool-helpers.js' + +/** + * Test plugin implementation. + */ +class TestPlugin implements Plugin { + public hookRegistered = false + private readonly _name: string + + constructor(name: string = 'test-plugin') { + this._name = name + } + + get name(): string { + return this._name + } + + initAgent(agent: LocalAgent): void { + agent.addHook(BeforeInvocationEvent, () => { + this.hookRegistered = true + }) + } +} + +/** + * Plugin with initAgent for testing initialization. + */ +class InitializableTestPlugin implements Plugin { + public initialized = false + + constructor(private readonly _name: string = 'initializable-plugin') {} + + get name(): string { + return this._name + } + + initAgent(_agent: LocalAgent): void { + this.initialized = true + } +} + +/** + * Plugin that provides tools. + */ +class ToolProviderPlugin implements Plugin { + constructor( + private readonly _name: string, + private readonly _tools: Tool[] + ) {} + + get name(): string { + return this._name + } + + initAgent(_agent: LocalAgent): void {} + + getTools(): Tool[] { + return this._tools + } +} + +describe('PluginRegistry', () => { + let registry: PluginRegistry + let mockAgent: LocalAgent + let registeredHooks: Array<{ + eventType: HookableEventConstructor + callback: HookCallback + }> + + beforeEach(() => { + registeredHooks = [] + mockAgent = createMockAgent({ + extra: { + addHook: (eventType: HookableEventConstructor, callback: HookCallback) => { + registeredHooks.push({ + eventType: eventType as HookableEventConstructor, + callback: callback as HookCallback, + }) + return () => {} + }, + }, + }) as unknown as LocalAgent + }) + + describe('initialize', () => { + it('initializes a plugin and calls initAgent', async () => { + const plugin = new InitializableTestPlugin() + registry = new PluginRegistry([plugin]) + + await registry.initialize(mockAgent) + + expect(plugin.initialized).toBe(true) + }) + + it('registers hooks via agent.addHook', async () => { + const plugin = new TestPlugin() + registry = new PluginRegistry([plugin]) + + await registry.initialize(mockAgent) + + expect(registeredHooks).toHaveLength(1) + expect(registeredHooks[0]?.eventType).toBe(BeforeInvocationEvent) + }) + + it('throws error when plugins have duplicate names', async () => { + const plugin1 = new TestPlugin('duplicate-name') + const plugin2 = new TestPlugin('duplicate-name') + registry = new PluginRegistry([plugin1, plugin2]) + + await expect(registry.initialize(mockAgent)).rejects.toThrow( + 'plugin_name= | plugin already registered' + ) + }) + + it('initializes multiple plugins with different names', async () => { + const plugin1 = new TestPlugin('plugin-1') + const plugin2 = new TestPlugin('plugin-2') + registry = new PluginRegistry([plugin1, plugin2]) + + await registry.initialize(mockAgent) + + expect(registeredHooks).toHaveLength(2) + }) + + it('auto-registers tools from plugin.getTools()', async () => { + const mockTool = createRandomTool('mock-tool') + const plugin = new ToolProviderPlugin('tool-provider', [mockTool]) + registry = new PluginRegistry([plugin]) + + await registry.initialize(mockAgent) + + expect(mockAgent.toolRegistry.get(mockTool.name)).toBe(mockTool) + }) + + it('handles async initAgent', async () => { + class AsyncPlugin implements Plugin { + public initialized = false + + get name(): string { + return 'async-plugin' + } + + async initAgent(_agent: LocalAgent): Promise { + await vi.waitFor(() => Promise.resolve()) + this.initialized = true + } + } + + const plugin = new AsyncPlugin() + registry = new PluginRegistry([plugin]) + + await registry.initialize(mockAgent) + + expect(plugin.initialized).toBe(true) + }) + + it('is idempotent — calling initialize twice only runs plugins once', async () => { + const plugin = new InitializableTestPlugin() + registry = new PluginRegistry([plugin]) + + await registry.initialize(mockAgent) + plugin.initialized = false // reset to detect a second call + await registry.initialize(mockAgent) + + expect(plugin.initialized).toBe(false) + }) + }) + + describe('hook invocation', () => { + it('hooks are invoked when callbacks are called', async () => { + const plugin = new TestPlugin() + registry = new PluginRegistry([plugin]) + await registry.initialize(mockAgent) + + const callback = registeredHooks[0]?.callback + const mockAgentData = {} as LocalAgent + callback?.(new BeforeInvocationEvent({ agent: mockAgentData, invocationState: {} })) + + expect(plugin.hookRegistered).toBe(true) + }) + }) +}) diff --git a/strands-ts/src/plugins/index.ts b/strands-ts/src/plugins/index.ts new file mode 100644 index 0000000000..0ad51ba31a --- /dev/null +++ b/strands-ts/src/plugins/index.ts @@ -0,0 +1,30 @@ +/** + * Plugin system for extending agent functionality. + * + * This module provides the Plugin base class for extending agent behavior + * through hook callbacks and tool registration. + * + * @example + * ```typescript + * import { Plugin, BeforeInvocationEvent } from '@strands-agents/sdk' + * + * class MyPlugin extends Plugin { + * get name(): string { + * return 'my-plugin' + * } + * + * override initAgent(agent: LocalAgent): void { + * agent.addHook(BeforeInvocationEvent, (event) => { + * console.log('Before invocation') + * }) + * } + * } + * + * const agent = new Agent({ + * model, + * plugins: [new MyPlugin()], + * }) + * ``` + */ + +export type { Plugin } from './plugin.js' diff --git a/strands-ts/src/plugins/model-plugin.ts b/strands-ts/src/plugins/model-plugin.ts new file mode 100644 index 0000000000..e8d715660a --- /dev/null +++ b/strands-ts/src/plugins/model-plugin.ts @@ -0,0 +1,34 @@ +import { AfterInvocationEvent } from '../hooks/events.js' +import { logger } from '../logging/logger.js' +import type { Model } from '../models/model.js' +import type { LocalAgent } from '../types/agent.js' +import type { Plugin } from './plugin.js' + +/** + * Built-in plugin that manages model-related lifecycle hooks. + * + * When the model is stateful (server-managed conversation state), this plugin + * clears the agent's local message history after each invocation since the + * server holds the authoritative conversation state. + * + * Internal: wired up automatically by Agent; not re-exported from the package + * entrypoint and not intended to be instantiated by consumers. + */ +export class ModelPlugin implements Plugin { + readonly name = 'strands:model' + private readonly _model: Model + + constructor(model: Model) { + this._model = model + } + + initAgent(agent: LocalAgent): void { + const model = this._model + agent.addHook(AfterInvocationEvent, () => { + if (model.stateful) { + agent.messages.length = 0 + logger.debug('cleared messages for server-managed conversation') + } + }) + } +} diff --git a/strands-ts/src/plugins/plugin.ts b/strands-ts/src/plugins/plugin.ts new file mode 100644 index 0000000000..b270875ecb --- /dev/null +++ b/strands-ts/src/plugins/plugin.ts @@ -0,0 +1,77 @@ +/** + * Plugin interface for extending agent functionality. + * + * This module defines the Plugin interface, which provides a composable way to + * add behavior changes to agents through hook registration and custom initialization. + */ + +import type { Tool } from '../tools/tool.js' +import type { LocalAgent } from '../types/agent.js' + +/** + * Interface for objects that extend agent functionality. + * + * Plugins provide a composable way to add behavior changes to agents by registering + * hook callbacks in their `initAgent` method. Each plugin must have a unique name + * for identification, logging, and duplicate prevention. + * + * @example + * ```typescript + * class LoggingPlugin implements Plugin { + * get name(): string { + * return 'logging-plugin' + * } + * + * initAgent(agent: LocalAgent): void { + * agent.addHook(BeforeInvocationEvent, (event) => { + * console.log('Agent invocation started') + * }) + * } + * } + * + * const agent = new Agent({ + * model, + * plugins: [new LoggingPlugin()], + * }) + * ``` + * + * @example With tools + * ```typescript + * class MyToolPlugin implements Plugin { + * get name(): string { + * return 'my-tool-plugin' + * } + * + * getTools(): Tool[] { + * return [myTool] + * } + * } + * ``` + */ +export interface Plugin { + /** + * A stable string identifier for the plugin. + * Used for logging, duplicate detection, and plugin management. + * + * For strands-vended plugins, names should be prefixed with `strands:`. + */ + readonly name: string + + /** + * Initialize the plugin with the agent instance. + * + * Implement this method to register hooks and perform custom initialization. + * Tool registration from {@link getTools} is handled automatically by the PluginRegistry. + * + * @param agent - The agent instance this plugin is being attached to + */ + initAgent(agent: LocalAgent): void | Promise + + /** + * Returns tools provided by this plugin for auto-registration. + * Implement to provide plugin-specific tools. + * + * @returns Array of tools to register with the agent + */ + getTools?(): Tool[] +} diff --git a/strands-ts/src/plugins/registry.ts b/strands-ts/src/plugins/registry.ts new file mode 100644 index 0000000000..ea6c3c79cf --- /dev/null +++ b/strands-ts/src/plugins/registry.ts @@ -0,0 +1,49 @@ +/** + * Plugin registry for managing plugins attached to an agent. + */ + +import type { Plugin } from './plugin.js' +import type { LocalAgent } from '../types/agent.js' + +/** + * Registry for managing plugins attached to an agent. + * + * Holds pending plugins and initializes them on first use. + * Handles duplicate detection, tool registration, and calls each plugin's initAgent method. + */ +export class PluginRegistry { + private readonly _plugins: Map + private readonly _pending: Plugin[] + + constructor(plugins: Plugin[] = []) { + this._plugins = new Map() + this._pending = [...plugins] + } + + /** + * Initialize all pending plugins with the agent. + * Safe to call multiple times — only runs once per pending batch. + * + * @param agent - The agent instance to initialize plugins with + */ + async initialize(agent: LocalAgent): Promise { + while (this._pending.length > 0) { + const plugin = this._pending.shift()! + await this._addAndInit(plugin, agent) + } + } + + private async _addAndInit(plugin: Plugin, agent: LocalAgent): Promise { + if (this._plugins.has(plugin.name)) { + throw new Error(`plugin_name=<${plugin.name}> | plugin already registered`) + } + this._plugins.set(plugin.name, plugin) + + const tools = plugin.getTools?.() ?? [] + if (tools.length > 0) { + agent.toolRegistry.add(tools) + } + + await plugin.initAgent(agent) + } +} diff --git a/strands-ts/src/registry/__tests__/tool-registry.test.ts b/strands-ts/src/registry/__tests__/tool-registry.test.ts new file mode 100644 index 0000000000..614a699e29 --- /dev/null +++ b/strands-ts/src/registry/__tests__/tool-registry.test.ts @@ -0,0 +1,259 @@ +import { describe, it, expect, beforeEach } from 'vitest' +import { ToolRegistry } from '../tool-registry.js' +import { ToolNotFoundError, ToolValidationError } from '../../errors.js' +import type { Tool, ToolStreamGenerator } from '../../tools/tool.js' +import { ToolStreamEvent } from '../../tools/tool.js' +import { ToolResultBlock } from '../../types/messages.js' + +const createMockTool = (overrides: Partial = {}): Tool => ({ + name: 'valid-tool', + description: 'A valid tool description.', + toolSpec: { + name: 'valid-tool', + description: 'A valid tool description.', + inputSchema: { type: 'object', properties: {} }, + }, + stream: async function* (): ToolStreamGenerator { + yield new ToolStreamEvent({ data: 'mock data' }) + return new ToolResultBlock({ toolUseId: '', status: 'success', content: [] }) + }, + ...overrides, +}) + +describe('ToolRegistry', () => { + let registry: ToolRegistry + + beforeEach(() => { + registry = new ToolRegistry() + }) + + describe('add', () => { + it('registers a single tool', () => { + const tool = createMockTool() + registry.add(tool) + expect(registry.list()).toStrictEqual([tool]) + }) + + it('registers an array of tools', () => { + const tool1 = createMockTool({ name: 'tool-1' }) + const tool2 = createMockTool({ name: 'tool-2' }) + registry.add([tool1, tool2]) + expect(registry.list()).toStrictEqual([tool1, tool2]) + }) + + it('throws ToolValidationError for a duplicate tool name', () => { + registry.add(createMockTool({ name: 'duplicate' })) + expect(() => registry.add(createMockTool({ name: 'duplicate' }))).toThrow(ToolValidationError) + expect(() => registry.add(createMockTool({ name: 'duplicate' }))).toThrow( + "Tool with name 'duplicate' already registered" + ) + }) + + it("throws ToolValidationError when a name differs only by '-' vs '_'", () => { + registry.add(createMockTool({ name: 'foo-bar' })) + expect(() => registry.add(createMockTool({ name: 'foo_bar' }))).toThrow(ToolValidationError) + expect(() => registry.add(createMockTool({ name: 'foo_bar' }))).toThrow( + "Tool name 'foo_bar' already exists as 'foo-bar'. Cannot add a duplicate tool which differs by a '-' or '_'" + ) + }) + + it('throws ToolValidationError for an invalid tool name pattern', () => { + expect(() => registry.add(createMockTool({ name: 'invalid name!' }))).toThrow(ToolValidationError) + expect(() => registry.add(createMockTool({ name: 'invalid name!' }))).toThrow( + 'Tool name must contain only alphanumeric characters, hyphens, and underscores' + ) + }) + + it('throws ToolValidationError for a tool name that is too long', () => { + expect(() => registry.add(createMockTool({ name: 'a'.repeat(65) }))).toThrow(ToolValidationError) + expect(() => registry.add(createMockTool({ name: 'a'.repeat(65) }))).toThrow( + 'Tool name must be between 1 and 64 characters' + ) + }) + + it('throws ToolValidationError for a tool name that is too short', () => { + expect(() => registry.add(createMockTool({ name: '' }))).toThrow(ToolValidationError) + expect(() => registry.add(createMockTool({ name: '' }))).toThrow('Tool name must be between 1 and 64 characters') + }) + + it('throws ToolValidationError for a non-string tool name', () => { + // @ts-expect-error - Testing invalid type for name + expect(() => registry.add(createMockTool({ name: 123 }))).toThrow(ToolValidationError) + // @ts-expect-error - Testing invalid type for name + expect(() => registry.add(createMockTool({ name: 123 }))).toThrow('Tool name must be a string') + }) + + it('throws ToolValidationError for an invalid description', () => { + // @ts-expect-error - Testing invalid type for description + expect(() => registry.add(createMockTool({ description: 123 }))).toThrow(ToolValidationError) + // @ts-expect-error - Testing invalid type for description + expect(() => registry.add(createMockTool({ description: 123 }))).toThrow( + 'Tool description must be a non-empty string' + ) + }) + + it('throws ToolValidationError for an empty string description', () => { + expect(() => registry.add(createMockTool({ description: '' }))).toThrow(ToolValidationError) + expect(() => registry.add(createMockTool({ description: '' }))).toThrow( + 'Tool description must be a non-empty string' + ) + }) + + it('allows a tool with a null or undefined description', () => { + const tool1 = createMockTool({ name: 'tool-1' }) + // @ts-expect-error - Testing explicit undefined description + tool1.description = undefined + + const tool2 = createMockTool({ name: 'tool-2' }) + // @ts-expect-error - Testing explicit null description + tool2.description = null + + registry.add([tool1, tool2]) + expect(registry.list()).toHaveLength(2) + }) + + it('registers a tool with a name at the maximum length', () => { + const tool = createMockTool({ name: 'a'.repeat(64) }) + expect(() => registry.add(tool)).not.toThrow() + }) + }) + + describe('addOrReplace', () => { + it('registers tools', () => { + const tool = createMockTool({ name: 'tool-1' }) + registry.addOrReplace([tool]) + expect(registry.get('tool-1')).toBe(tool) + }) + + it('replaces an existing tool with the same name', () => { + const original = createMockTool({ name: 'tool-1', description: 'original' }) + const replacement = createMockTool({ name: 'tool-1', description: 'replacement' }) + registry.add(original) + registry.addOrReplace([replacement]) + expect(registry.get('tool-1')).toBe(replacement) + }) + + it('validates tool properties', () => { + expect(() => registry.addOrReplace([createMockTool({ name: 'invalid name!' })])).toThrow(ToolValidationError) + }) + + it("throws ToolValidationError when a new tool name differs only by '-' vs '_'", () => { + registry.add(createMockTool({ name: 'foo-bar' })) + expect(() => registry.addOrReplace([createMockTool({ name: 'foo_bar' })])).toThrow(ToolValidationError) + }) + }) + + describe('get', () => { + it('retrieves a tool by name', () => { + const tool = createMockTool({ name: 'find-me' }) + registry.add(tool) + expect(registry.get('find-me')).toBe(tool) + }) + + it('returns undefined for a non-existent tool', () => { + expect(registry.get('non-existent')).toBeUndefined() + }) + }) + + describe('resolve', () => { + it('returns the tool for an exact name match', () => { + const tool = createMockTool({ name: 'my-tool' }) + registry.add(tool) + expect(registry.resolve('my-tool')).toBe(tool) + }) + + it('resolves underscore-to-hyphen substitution', () => { + const tool = createMockTool({ name: 'my-tool' }) + registry.add(tool) + expect(registry.resolve('my_tool')).toBe(tool) + }) + + it('resolves case-insensitively', () => { + const tool = createMockTool({ name: 'MyTool' }) + registry.add(tool) + expect(registry.resolve('mytool')).toBe(tool) + }) + + it('prefers exact match over case-insensitive match', () => { + const exact = createMockTool({ name: 'mytool' }) + const cased = createMockTool({ name: 'MYTOOL' }) + // exact must come first because the validator forbids names that differ + // only by '-'/'_'; case-only diffs are allowed. + registry.add([exact, cased]) + expect(registry.resolve('mytool')).toBe(exact) + }) + + it('prefers exact match over underscore-to-hyphen match', () => { + const exact = createMockTool({ name: 'my_tool' }) + registry.add(exact) + // No hyphen variant present — exact is the only candidate. + expect(registry.resolve('my_tool')).toBe(exact) + }) + + it('throws ToolNotFoundError when no tool matches', () => { + registry.add(createMockTool({ name: 'existing-tool' })) + expect(() => registry.resolve('nonexistent')).toThrow(ToolNotFoundError) + }) + + it('attaches the requested name to the thrown ToolNotFoundError', () => { + try { + registry.resolve('missing') + throw new Error('expected resolve() to throw') + } catch (e) { + expect(e).toBeInstanceOf(ToolNotFoundError) + expect((e as ToolNotFoundError).toolName).toBe('missing') + expect((e as ToolNotFoundError).message).toBe("Tool 'missing' not found") + } + }) + + it('throws ToolNotFoundError when registry is empty', () => { + expect(() => registry.resolve('anything')).toThrow(ToolNotFoundError) + }) + }) + + describe('remove', () => { + it('removes a tool by name', () => { + registry.add(createMockTool({ name: 'remove-me' })) + registry.remove('remove-me') + expect(registry.get('remove-me')).toBeUndefined() + }) + + it('does not throw when removing a non-existent tool', () => { + expect(() => registry.remove('non-existent')).not.toThrow() + }) + }) + + describe('list', () => { + it('returns an empty array when no tools are registered', () => { + expect(registry.list()).toStrictEqual([]) + }) + + it('returns all registered tools', () => { + const tool1 = createMockTool({ name: 'tool-1' }) + const tool2 = createMockTool({ name: 'tool-2' }) + registry.add([tool1, tool2]) + expect(registry.list()).toStrictEqual([tool1, tool2]) + }) + }) + + describe('clear', () => { + it('should remove all registered tools', () => { + registry.add([createMockTool({ name: 'tool-1' }), createMockTool({ name: 'tool-2' })]) + registry.clear() + expect(registry.list()).toStrictEqual([]) + }) + + it('should be a no-op on an empty registry', () => { + expect(() => registry.clear()).not.toThrow() + expect(registry.list()).toStrictEqual([]) + }) + }) + + describe('constructor', () => { + it('accepts initial tools', () => { + const tool = createMockTool() + const reg = new ToolRegistry([tool]) + expect(reg.list()).toStrictEqual([tool]) + }) + }) +}) diff --git a/strands-ts/src/registry/tool-registry.ts b/strands-ts/src/registry/tool-registry.ts new file mode 100644 index 0000000000..9dacf17259 --- /dev/null +++ b/strands-ts/src/registry/tool-registry.ts @@ -0,0 +1,161 @@ +import type { Tool } from '../tools/tool.js' +import { ToolValidationError, ToolNotFoundError } from '../errors.js' + +/** + * Registry for managing Tool instances with name-based CRUDL operations. + */ +export class ToolRegistry { + private _tools: Map = new Map() + + /** + * Creates a new ToolRegistry, optionally pre-populated with tools. + * + * @param tools - Optional initial tools to register + */ + constructor(tools?: Tool[]) { + if (tools) { + this.add(tools) + } + } + + /** + * Registers one or more tools. + * + * @param tool - A single tool or array of tools to register + * @throws ToolValidationError If a tool's properties are invalid or its name is already registered + */ + add(tool: Tool | Tool[]): void { + const tools = Array.isArray(tool) ? tool : [tool] + for (const t of tools) { + this._validateProperties(t) + if (this._tools.has(t.name)) { + throw new ToolValidationError(`Tool with name '${t.name}' already registered`) + } + this._checkNormalizedConflict(t.name) + this._tools.set(t.name, t) + } + } + + /** + * Registers one or more tools, replacing any existing tools with the same name. + * + * @param tools - Array of tools to register + * @throws ToolValidationError If a tool's properties are invalid + */ + addOrReplace(newTools: Tool[]): void { + for (const tool of newTools) { + this._validateProperties(tool) + if (!this._tools.has(tool.name)) { + this._checkNormalizedConflict(tool.name) + } + this._tools.set(tool.name, tool) + } + } + + /** + * Retrieves a tool by name. + * + * @param name - The name of the tool to retrieve + * @returns The tool if found, otherwise undefined + */ + get(name: string): Tool | undefined { + return this._tools.get(name) + } + + /** + * Resolves a tool name using normalization strategies and returns the tool. + * + * Resolution order: + * 1. Exact match + * 2. Underscore-to-hyphen substitution (e.g. `my_tool` → `my-tool`) + * 3. Case-insensitive match + * + * @param name - The name to look up + * @returns The resolved tool + * @throws ToolNotFoundError if no tool with the given name exists + */ + resolve(name: string): Tool { + // 1. Direct match + const exact = this._tools.get(name) + if (exact) { + return exact + } + + const tools = this.list() + + // 2. Underscore-to-hyphen normalization + if (name.includes('_')) { + const match = tools.find((t) => t.name.replace(/-/g, '_') === name) + if (match) { + return match + } + } + + // 3. Case-insensitive match + const lowerName = name.toLowerCase() + const caseMatch = tools.find((t) => t.name.toLowerCase() === lowerName) + if (caseMatch) { + return caseMatch + } + + throw new ToolNotFoundError(name) + } + + /** + * Removes a tool by name. No-op if the tool does not exist. + * + * @param name - The name of the tool to remove + */ + remove(name: string): void { + this._tools.delete(name) + } + + /** + * Removes all registered tools. + */ + clear(): void { + this._tools.clear() + } + + /** + * Returns all registered tools. + * + * @returns Array of all registered tools + */ + list(): Tool[] { + return Array.from(this._tools.values()) + } + + private _validateProperties(tool: Tool): void { + if (typeof tool.name !== 'string') { + throw new ToolValidationError('Tool name must be a string') + } + + if (tool.name.length < 1 || tool.name.length > 64) { + throw new ToolValidationError('Tool name must be between 1 and 64 characters') + } + + const validNamePattern = /^[a-zA-Z0-9_-]+$/ + if (!validNamePattern.test(tool.name)) { + throw new ToolValidationError('Tool name must contain only alphanumeric characters, hyphens, and underscores') + } + + if (tool.description !== undefined && tool.description !== null) { + if (typeof tool.description !== 'string' || tool.description.length < 1) { + throw new ToolValidationError('Tool description must be a non-empty string') + } + } + } + + private _checkNormalizedConflict(name: string): void { + const normalized = name.replaceAll('-', '_') + for (const existing of this._tools.keys()) { + if (existing !== name && existing.replaceAll('-', '_') === normalized) { + throw new ToolValidationError( + `Tool name '${name}' already exists as '${existing}'.` + + " Cannot add a duplicate tool which differs by a '-' or '_'" + ) + } + } + } +} diff --git a/strands-ts/src/retry/__tests__/backoff-strategy.test.ts b/strands-ts/src/retry/__tests__/backoff-strategy.test.ts new file mode 100644 index 0000000000..7ce7d57878 --- /dev/null +++ b/strands-ts/src/retry/__tests__/backoff-strategy.test.ts @@ -0,0 +1,138 @@ +// All tests here are purely synchronous — strategies compute a delay number; +// no timers fire and nothing is awaited, so tests never wait real wall time. + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { ConstantBackoff, LinearBackoff, ExponentialBackoff, type BackoffContext } from '../backoff-strategy.js' + +function ctx(partial: Partial = {}): BackoffContext { + return { attempt: 1, elapsedMs: 0, ...partial } +} + +describe('ConstantBackoff', () => { + it('returns the configured delay regardless of attempt', () => { + const bo = new ConstantBackoff({ delayMs: 250 }) + expect(bo.nextDelay(ctx({ attempt: 1 }))).toBe(250) + expect(bo.nextDelay(ctx({ attempt: 5 }))).toBe(250) + }) + + it('defaults delayMs to 1000', () => { + expect(new ConstantBackoff().nextDelay(ctx())).toBe(1000) + }) + + it('rejects attempts below 1', () => { + const bo = new ConstantBackoff() + expect(() => bo.nextDelay(ctx({ attempt: 0 }))).toThrow(/attempt must be an integer/) + }) +}) + +describe('LinearBackoff', () => { + beforeEach(() => { + vi.spyOn(Math, 'random').mockReturnValue(0.5) + }) + afterEach(() => { + vi.restoreAllMocks() + }) + + it('grows as baseMs * attempt', () => { + const bo = new LinearBackoff({ baseMs: 100, jitter: 'none' }) + expect(bo.nextDelay(ctx({ attempt: 1 }))).toBe(100) + expect(bo.nextDelay(ctx({ attempt: 2 }))).toBe(200) + expect(bo.nextDelay(ctx({ attempt: 3 }))).toBe(300) + }) + + it('clamps to maxMs before jitter', () => { + const bo = new LinearBackoff({ baseMs: 1000, maxMs: 2500, jitter: 'none' }) + expect(bo.nextDelay(ctx({ attempt: 10 }))).toBe(2500) + }) + + it('applies full jitter by default (Math.random() * raw)', () => { + const bo = new LinearBackoff({ baseMs: 100 }) + // attempt 4 → raw 400, Math.random() mocked to 0.5 → 200 + expect(bo.nextDelay(ctx({ attempt: 4 }))).toBe(200) + }) + + it('rejects attempts below 1', () => { + const bo = new LinearBackoff() + expect(() => bo.nextDelay(ctx({ attempt: 0 }))).toThrow(/attempt must be an integer/) + }) +}) + +describe('ExponentialBackoff', () => { + beforeEach(() => { + vi.spyOn(Math, 'random').mockReturnValue(0.5) + }) + afterEach(() => { + vi.restoreAllMocks() + }) + + it('grows as baseMs * multiplier^(attempt-1)', () => { + const bo = new ExponentialBackoff({ baseMs: 100, jitter: 'none' }) + expect(bo.nextDelay(ctx({ attempt: 1 }))).toBe(100) + expect(bo.nextDelay(ctx({ attempt: 2 }))).toBe(200) + expect(bo.nextDelay(ctx({ attempt: 3 }))).toBe(400) + expect(bo.nextDelay(ctx({ attempt: 4 }))).toBe(800) + }) + + it('honors custom multiplier', () => { + const bo = new ExponentialBackoff({ baseMs: 100, multiplier: 3, jitter: 'none' }) + expect(bo.nextDelay(ctx({ attempt: 3 }))).toBe(900) + }) + + it('clamps to maxMs', () => { + const bo = new ExponentialBackoff({ baseMs: 100, maxMs: 500, jitter: 'none' }) + expect(bo.nextDelay(ctx({ attempt: 10 }))).toBe(500) + }) + + it('applies full jitter by default', () => { + const bo = new ExponentialBackoff({ baseMs: 100 }) + // attempt 3 → raw 400, Math.random() mocked to 0.5 → 200 + expect(bo.nextDelay(ctx({ attempt: 3 }))).toBe(200) + }) + + it('applies equal jitter as raw/2 + random*raw/2', () => { + const bo = new ExponentialBackoff({ baseMs: 100, jitter: 'equal' }) + // attempt 2 → raw 200, equal → 100 + 0.5*100 = 150 + expect(bo.nextDelay(ctx({ attempt: 2 }))).toBe(150) + }) + + it('applies no jitter when set to none', () => { + const bo = new ExponentialBackoff({ baseMs: 100, jitter: 'none' }) + expect(bo.nextDelay(ctx({ attempt: 3 }))).toBe(400) + }) + + it('falls back to full jitter for decorrelated when lastDelayMs missing', () => { + const bo = new ExponentialBackoff({ baseMs: 100, jitter: 'decorrelated' }) + expect(bo.nextDelay(ctx({ attempt: 3 }))).toBe(200) + }) + + it('applies decorrelated jitter as uniform(baseMs, min(maxMs, lastDelayMs*3))', () => { + const bo = new ExponentialBackoff({ baseMs: 100, maxMs: 10_000, jitter: 'decorrelated' }) + // lastDelayMs=200 → upper=min(10_000, 600)=600 + // random=0.5 → 100 + 0.5 * (600-100) = 350 + expect(bo.nextDelay(ctx({ attempt: 2, lastDelayMs: 200 }))).toBe(350) + }) + + it('caps decorrelated upper at maxMs', () => { + const bo = new ExponentialBackoff({ baseMs: 100, maxMs: 500, jitter: 'decorrelated' }) + // lastDelayMs=1000 → upper=min(500, 3000)=500 + // random=0.5 → 100 + 0.5 * (500-100) = 300 + expect(bo.nextDelay(ctx({ attempt: 2, lastDelayMs: 1000 }))).toBe(300) + }) + + it('floors decorrelated upper at baseMs when lastDelay*3 < baseMs', () => { + // Guards against inverted range: without max(baseMs, ...), upper=30 would be + // below lower=100. The max clamp yields upper=baseMs, so delay stays at baseMs. + const bo = new ExponentialBackoff({ baseMs: 100, maxMs: 500, jitter: 'decorrelated' }) + expect(bo.nextDelay(ctx({ attempt: 2, lastDelayMs: 10 }))).toBe(100) + }) + + it('rejects attempts below 1', () => { + const bo = new ExponentialBackoff() + expect(() => bo.nextDelay(ctx({ attempt: 0 }))).toThrow(/attempt must be an integer/) + }) + + it('rejects non-integer attempts', () => { + const bo = new ExponentialBackoff() + expect(() => bo.nextDelay(ctx({ attempt: 1.5 }))).toThrow(/attempt must be an integer/) + }) +}) diff --git a/strands-ts/src/retry/__tests__/default-model-retry-strategy.test.ts b/strands-ts/src/retry/__tests__/default-model-retry-strategy.test.ts new file mode 100644 index 0000000000..7f2bf4e74d --- /dev/null +++ b/strands-ts/src/retry/__tests__/default-model-retry-strategy.test.ts @@ -0,0 +1,265 @@ +// Tests use vi.useFakeTimers() so the internal `await sleep(...)` never waits +// real wall time — timers are advanced manually with vi.advanceTimersByTimeAsync. + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { DefaultModelRetryStrategy } from '../default-model-retry-strategy.js' +import { ModelRetryStrategy } from '../model-retry-strategy.js' +import type { RetryDecision } from '../retry-strategy.js' +import { ConstantBackoff, type BackoffStrategy } from '../backoff-strategy.js' +import { AfterModelCallEvent } from '../../hooks/events.js' +import { ModelThrottledError } from '../../errors.js' +import { createMockAgent, invokeTrackedHook, type MockAgent } from '../../__fixtures__/agent-helpers.js' + +function makeErrorEvent(agent: MockAgent, error: Error, attemptCount: number): AfterModelCallEvent { + return new AfterModelCallEvent({ agent, model: {} as never, attemptCount, error, invocationState: {} }) +} + +describe('DefaultModelRetryStrategy', () => { + beforeEach(() => { + vi.useFakeTimers() + }) + afterEach(() => { + vi.useRealTimers() + }) + + it('registers an AfterModelCallEvent hook', () => { + const strategy = new DefaultModelRetryStrategy() + const agent = createMockAgent() + strategy.initAgent(agent) + const types = agent.trackedHooks.map((h) => h.eventType) + expect(types).toContain(AfterModelCallEvent) + }) + + it('exposes the plugin name', () => { + expect(new DefaultModelRetryStrategy().name).toBe('strands:default-model-retry-strategy') + }) + + it('is a ModelRetryStrategy', () => { + expect(new DefaultModelRetryStrategy()).toBeInstanceOf(ModelRetryStrategy) + }) + + it('rejects maxAttempts below 1', () => { + expect(() => new DefaultModelRetryStrategy({ maxAttempts: 0 })).toThrow(/maxAttempts/) + }) + + it('sets retry=true on ModelThrottledError and sleeps for the configured delay', async () => { + const strategy = new DefaultModelRetryStrategy({ + maxAttempts: 3, + backoff: new ConstantBackoff({ delayMs: 500 }), + }) + const agent = createMockAgent() + strategy.initAgent(agent) + + const event = makeErrorEvent(agent, new ModelThrottledError('rate limited'), 1) + const pending = invokeTrackedHook(agent, event) + + // Before the timer advances, the hook is still awaiting sleep — retry not yet set. + await vi.advanceTimersByTimeAsync(499) + expect(event.retry).toBeUndefined() + + await vi.advanceTimersByTimeAsync(1) + await pending + expect(event.retry).toBe(true) + }) + + it('does not retry non-retryable errors', async () => { + const strategy = new DefaultModelRetryStrategy({ + backoff: new ConstantBackoff({ delayMs: 10 }), + }) + const agent = createMockAgent() + strategy.initAgent(agent) + + const event = makeErrorEvent(agent, new Error('boom'), 1) + await invokeTrackedHook(agent, event) + expect(event.retry).toBeUndefined() + }) + + it('stops retrying once maxAttempts is reached', async () => { + const strategy = new DefaultModelRetryStrategy({ + maxAttempts: 3, + backoff: new ConstantBackoff({ delayMs: 1 }), + }) + const agent = createMockAgent() + strategy.initAgent(agent) + + // Attempt 1 → retry + const e1 = makeErrorEvent(agent, new ModelThrottledError('x'), 1) + const p1 = invokeTrackedHook(agent, e1) + await vi.advanceTimersByTimeAsync(1) + await p1 + expect(e1.retry).toBe(true) + + // Attempt 2 → retry + const e2 = makeErrorEvent(agent, new ModelThrottledError('x'), 2) + const p2 = invokeTrackedHook(agent, e2) + await vi.advanceTimersByTimeAsync(1) + await p2 + expect(e2.retry).toBe(true) + + // Attempt 3 → at max, should not retry + const e3 = makeErrorEvent(agent, new ModelThrottledError('x'), 3) + await invokeTrackedHook(agent, e3) + expect(e3.retry).toBeUndefined() + }) + + it('skips work if another hook already requested retry', async () => { + const strategy = new DefaultModelRetryStrategy({ + maxAttempts: 5, + backoff: new ConstantBackoff({ delayMs: 1000 }), + }) + const agent = createMockAgent() + strategy.initAgent(agent) + + const event = makeErrorEvent(agent, new ModelThrottledError('x'), 1) + event.retry = true + + // Should return immediately with no sleep — if it tried to sleep we'd see + // hung test state; resolving without advancing timers proves the skip. + await invokeTrackedHook(agent, event) + expect(event.retry).toBe(true) + }) + + it('clears backoff state at the start of each new turn', async () => { + // The strategy resets state on `attemptCount === 1` regardless of how + // the prior turn ended. This exercises that: a turn racks up a retry + // (lastDelayMs = 5), then the next turn's first attempt must see a + // fresh BackoffContext (no lastDelayMs). + const nextDelay = vi.fn().mockReturnValue(5) + const backoff: BackoffStrategy = { nextDelay } + const strategy = new DefaultModelRetryStrategy({ maxAttempts: 5, backoff }) + const agent = createMockAgent() + strategy.initAgent(agent) + + // Turn 1 → fail → lastDelayMs gets set to 5. + const e1 = makeErrorEvent(agent, new ModelThrottledError('x'), 1) + const p1 = invokeTrackedHook(agent, e1) + await vi.advanceTimersByTimeAsync(5) + await p1 + + // Turn 2 → fail on first attempt → should see no lastDelayMs. + const e2 = makeErrorEvent(agent, new ModelThrottledError('x'), 1) + const p2 = invokeTrackedHook(agent, e2) + await vi.advanceTimersByTimeAsync(5) + await p2 + + expect(nextDelay.mock.calls[1]![0]).toEqual({ + attempt: 1, + elapsedMs: expect.any(Number), + }) + }) + + it('passes BackoffContext with attempt and lastDelayMs to the backoff strategy', async () => { + const nextDelay = vi.fn().mockReturnValue(5) + const backoff: BackoffStrategy = { nextDelay } + const strategy = new DefaultModelRetryStrategy({ maxAttempts: 5, backoff }) + const agent = createMockAgent() + strategy.initAgent(agent) + + const e1 = makeErrorEvent(agent, new ModelThrottledError('x'), 1) + const p1 = invokeTrackedHook(agent, e1) + await vi.advanceTimersByTimeAsync(5) + await p1 + + expect(nextDelay).toHaveBeenCalledTimes(1) + expect(nextDelay.mock.calls[0]![0]).toEqual({ + attempt: 1, + elapsedMs: expect.any(Number), + }) + + const e2 = makeErrorEvent(agent, new ModelThrottledError('x'), 2) + const p2 = invokeTrackedHook(agent, e2) + await vi.advanceTimersByTimeAsync(5) + await p2 + + expect(nextDelay).toHaveBeenCalledTimes(2) + expect(nextDelay.mock.calls[1]![0]).toEqual({ + attempt: 2, + elapsedMs: expect.any(Number), + lastDelayMs: 5, + }) + }) + + it('clears per-turn state on attempt 1 even when a prior hook already set event.retry', async () => { + // Regression: onFirstModelAttempt must fire before the event.retry short-circuit. + // Otherwise state from a prior turn leaks into the new turn's BackoffContext. + const nextDelay = vi.fn().mockReturnValue(5) + const backoff: BackoffStrategy = { nextDelay } + const strategy = new DefaultModelRetryStrategy({ maxAttempts: 5, backoff }) + const agent = createMockAgent() + strategy.initAgent(agent) + + // Turn 1 → fail → lastDelayMs gets set to 5. + const e1 = makeErrorEvent(agent, new ModelThrottledError('x'), 1) + const p1 = invokeTrackedHook(agent, e1) + await vi.advanceTimersByTimeAsync(5) + await p1 + + // Turn 2 → attempt 1 → another hook already set retry=true before us. + // We should still clear state (onFirstModelAttempt runs first), even though + // we short-circuit and don't call computeRetryDecision. + const e2 = makeErrorEvent(agent, new ModelThrottledError('x'), 1) + e2.retry = true + await invokeTrackedHook(agent, e2) + + // Turn 2 → attempt 2 → backoff should see no lastDelayMs from turn 1. + const e3 = makeErrorEvent(agent, new ModelThrottledError('x'), 2) + const p3 = invokeTrackedHook(agent, e3) + await vi.advanceTimersByTimeAsync(5) + await p3 + + // Second call is turn 2 attempt 2; must not carry turn 1's lastDelayMs. + expect(nextDelay.mock.calls[1]![0]).toEqual({ + attempt: 2, + elapsedMs: expect.any(Number), + }) + }) + + it('lets subclasses expand the retryable set by overriding isRetryable', async () => { + class CustomError extends Error {} + + class PermissiveStrategy extends DefaultModelRetryStrategy { + override readonly name = 'test:permissive' + protected override isRetryable(error: Error): boolean { + return super.isRetryable(error) || error instanceof CustomError + } + } + + const strategy = new PermissiveStrategy({ + maxAttempts: 3, + backoff: new ConstantBackoff({ delayMs: 10 }), + }) + const agent = createMockAgent() + strategy.initAgent(agent) + + const event = makeErrorEvent(agent, new CustomError('custom'), 1) + const pending = invokeTrackedHook(agent, event) + await vi.advanceTimersByTimeAsync(10) + await pending + + expect(event.retry).toBe(true) + }) + + it('short-circuits without retry when computeRetryDecision returns retry:false for a non-max reason', async () => { + // Exercises the computeRetryDecision "return { retry: false }" branch that + // isn't about maxAttempts. A subclass declines to retry a specific error + // instance even though the classifier said it was retryable in principle. + class PickyStrategy extends DefaultModelRetryStrategy { + override readonly name = 'test:picky' + protected override computeRetryDecision(event: AfterModelCallEvent): RetryDecision { + if ((event.error as Error).message === 'skip') return { retry: false } + return super.computeRetryDecision(event) + } + } + + const strategy = new PickyStrategy({ + maxAttempts: 5, + backoff: new ConstantBackoff({ delayMs: 10 }), + }) + const agent = createMockAgent() + strategy.initAgent(agent) + + const event = makeErrorEvent(agent, new ModelThrottledError('skip'), 1) + await invokeTrackedHook(agent, event) + expect(event.retry).toBeUndefined() + }) +}) diff --git a/strands-ts/src/retry/backoff-strategy.ts b/strands-ts/src/retry/backoff-strategy.ts new file mode 100644 index 0000000000..bb7ee00300 --- /dev/null +++ b/strands-ts/src/retry/backoff-strategy.ts @@ -0,0 +1,171 @@ +/** + * Backoff strategies for computing delay between retry attempts. + * + * A `BackoffStrategy` is pure delay math: given a `BackoffContext`, it returns + * how long to wait before the next attempt. Policy concerns — whether to retry, + * whether to honor a server-provided `Retry-After` hint, max attempts, total + * time budgets — live in the retry orchestration layer, not here. + */ + +/** + * Context passed to a {@link BackoffStrategy} for each retry decision. + * + * Treated as an open, additive-only contract: new optional fields may be added + * over time, but existing fields will not be removed or repurposed. + */ +export interface BackoffContext { + /** 1-based index of the attempt that just failed. Must be \>= 1. */ + attempt: number + /** Total milliseconds elapsed since the first attempt started. */ + elapsedMs: number + /** Previously computed delay, if any. Absent before the first retry. */ + lastDelayMs?: number +} + +/** + * Computes the delay before the next retry attempt. + */ +export interface BackoffStrategy { + /** + * Returns the delay in milliseconds before the next attempt. + * + * Must be a non-negative finite number. Implementations should treat + * `ctx.attempt < 1` as a programmer error. + */ + nextDelay(ctx: BackoffContext): number +} + +/** + * Supported jitter modes. + * + * - `none`: return the raw delay unchanged + * - `full`: uniform random in `[0, raw]` + * - `equal`: `raw/2 + uniform(0, raw/2)` (half fixed, half random) + * - `decorrelated`: `uniform(baseMs, lastDelayMs * 3)`, capped at `maxMs`; + * falls back to `full` on the first retry when `lastDelayMs` is unavailable + * + * For jitter outside these modes, implement {@link BackoffStrategy} directly. + */ +export type JitterKind = 'none' | 'full' | 'equal' | 'decorrelated' + +function validateAttempt(attempt: number, className: string): void { + if (!Number.isInteger(attempt) || attempt < 1) { + throw new Error(`${className}: attempt must be an integer >= 1 (got ${attempt})`) + } +} + +/** + * Options for {@link ConstantBackoff}. + */ +export interface ConstantBackoffOptions { + /** Delay in ms returned for every retry. Default 1000. */ + delayMs?: number +} + +/** + * Constant backoff: returns the same delay for every retry. + */ +export class ConstantBackoff implements BackoffStrategy { + private readonly _delayMs: number + + constructor(opts: ConstantBackoffOptions = {}) { + this._delayMs = opts.delayMs ?? 1000 + } + + nextDelay(ctx: BackoffContext): number { + validateAttempt(ctx.attempt, 'ConstantBackoff') + return this._delayMs + } +} + +/** + * Options for {@link LinearBackoff}. + */ +export interface LinearBackoffOptions { + /** Base delay in ms. Delay grows as `baseMs * attempt`. Default 1000. */ + baseMs?: number + /** Upper bound applied before jitter. Default 30_000. */ + maxMs?: number + /** Jitter mode. Default 'full'. */ + jitter?: JitterKind +} + +/** + * Linear backoff: delay grows as `baseMs * attempt`, capped at `maxMs`, then jittered. + */ +export class LinearBackoff implements BackoffStrategy { + private readonly _baseMs: number + private readonly _maxMs: number + private readonly _jitter: JitterKind + + constructor(opts: LinearBackoffOptions = {}) { + this._baseMs = opts.baseMs ?? 1000 + this._maxMs = opts.maxMs ?? 30_000 + this._jitter = opts.jitter ?? 'full' + } + + nextDelay(ctx: BackoffContext): number { + validateAttempt(ctx.attempt, 'LinearBackoff') + const raw = Math.min(this._maxMs, this._baseMs * ctx.attempt) + return jitter(raw, this._jitter, this._baseMs, this._maxMs, ctx.lastDelayMs) + } +} + +/** + * Options for {@link ExponentialBackoff}. + */ +export interface ExponentialBackoffOptions { + /** Base delay in ms. Delay grows as `baseMs * multiplier^(attempt-1)`. Default 1000. */ + baseMs?: number + /** Upper bound applied before jitter. Default 30_000. */ + maxMs?: number + /** Growth factor per attempt. Default 2. */ + multiplier?: number + /** Jitter mode. Default 'full'. */ + jitter?: JitterKind +} + +/** + * Exponential backoff: delay grows as `baseMs * multiplier^(attempt-1)`, + * capped at `maxMs`, then jittered. + */ +export class ExponentialBackoff implements BackoffStrategy { + private readonly _baseMs: number + private readonly _maxMs: number + private readonly _multiplier: number + private readonly _jitter: JitterKind + + constructor(opts: ExponentialBackoffOptions = {}) { + this._baseMs = opts.baseMs ?? 1000 + this._maxMs = opts.maxMs ?? 30_000 + this._multiplier = opts.multiplier ?? 2 + this._jitter = opts.jitter ?? 'full' + } + + nextDelay(ctx: BackoffContext): number { + validateAttempt(ctx.attempt, 'ExponentialBackoff') + const raw = Math.min(this._maxMs, this._baseMs * this._multiplier ** (ctx.attempt - 1)) + return jitter(raw, this._jitter, this._baseMs, this._maxMs, ctx.lastDelayMs) + } +} + +function jitter(raw: number, kind: JitterKind, baseMs: number, maxMs: number, lastDelayMs?: number): number { + switch (kind) { + case 'none': + return raw + case 'full': + return Math.random() * raw + case 'equal': + return raw / 2 + Math.random() * (raw / 2) + case 'decorrelated': { + if (lastDelayMs === undefined) { + return Math.random() * raw + } + // Standard decorrelated jitter: uniform(baseMs, min(maxMs, lastDelay * 3)). + // The max() guards against the degenerate case where maxMs < baseMs, + // which would otherwise produce an inverted range. + const upper = Math.max(baseMs, Math.min(maxMs, lastDelayMs * 3)) + return baseMs + Math.random() * (upper - baseMs) + } + } +} diff --git a/strands-ts/src/retry/default-model-retry-strategy.ts b/strands-ts/src/retry/default-model-retry-strategy.ts new file mode 100644 index 0000000000..ca2453cca4 --- /dev/null +++ b/strands-ts/src/retry/default-model-retry-strategy.ts @@ -0,0 +1,141 @@ +/** + * Default concrete retry strategy for model invocations. + * + * Implements {@link ModelRetryStrategy.computeRetryDecision} to retry failed model + * calls classified by {@link isRetryable}, bounded by `maxAttempts`, with + * delays computed by the configured {@link BackoffStrategy}. + * + * The attempt counter lives on {@link AfterModelCallEvent.attemptCount}, + * maintained by the agent loop. This strategy only keeps per-turn backoff + * state (first-failure timestamp, last delay), which is cleared in + * {@link onFirstModelAttempt}. + */ + +import type { AfterModelCallEvent } from '../hooks/events.js' +import { ModelThrottledError } from '../errors.js' +import { logger } from '../logging/logger.js' +import type { BackoffContext, BackoffStrategy } from './backoff-strategy.js' +import { ExponentialBackoff } from './backoff-strategy.js' +import { ModelRetryStrategy } from './model-retry-strategy.js' +import type { RetryDecision } from './retry-strategy.js' + +const DEFAULT_MAX_ATTEMPTS = 6 +const DEFAULT_BACKOFF_BASE_MS = 4_000 +const DEFAULT_BACKOFF_MAX_MS = 240_000 + +/** + * Options for {@link DefaultModelRetryStrategy}. + */ +export interface DefaultModelRetryStrategyOptions { + /** + * Total model attempts before giving up and re-raising the error. + * Must be \>= 1. Default {@link DEFAULT_MAX_ATTEMPTS}. + */ + maxAttempts?: number + /** + * Backoff used to compute the delay between retries. + * Default: `new ExponentialBackoff({ baseMs: DEFAULT_BACKOFF_BASE_MS, maxMs: DEFAULT_BACKOFF_MAX_MS })`. + */ + backoff?: BackoffStrategy +} + +/** + * Retries failed model calls classified by the SDK as retryable. + * + * Today, only {@link ModelThrottledError} is treated as retryable — subclass + * and override {@link isRetryable} to expand or narrow that set without + * reimplementing the rest of the retry policy. + * + * State is per-turn: backoff timing state resets in {@link onFirstModelAttempt}, + * which the base class calls when `event.attemptCount === 1`. The attempt + * counter itself is owned by the agent loop and read off + * {@link AfterModelCallEvent.attemptCount}. + * + * Hook precedence: {@link AfterModelCallEvent} fires hooks in reverse registration + * order, so user-registered hooks run before this strategy. If a user hook sets + * `event.retry = true` first, the base class returns early and does not stack + * additional backoff on top. + * + * Sharing: a given instance tracks its own backoff state and must not be shared + * across multiple agents. Create a separate instance per agent. + * + * @example + * ```ts + * const agent = new Agent({ + * model, + * retryStrategy: new DefaultModelRetryStrategy({ maxAttempts: 4 }), + * }) + * ``` + */ +export class DefaultModelRetryStrategy extends ModelRetryStrategy { + readonly name: string = 'strands:default-model-retry-strategy' + + private readonly _maxAttempts: number + private readonly _backoff: BackoffStrategy + + private _lastDelayMs: number | undefined + private _firstFailureAt: number | undefined + + constructor(opts: DefaultModelRetryStrategyOptions = {}) { + super() + const maxAttempts = opts.maxAttempts ?? DEFAULT_MAX_ATTEMPTS + if (!Number.isInteger(maxAttempts) || maxAttempts < 1) { + throw new Error(`DefaultModelRetryStrategy: maxAttempts must be an integer >= 1 (got ${maxAttempts})`) + } + this._maxAttempts = maxAttempts + this._backoff = + opts.backoff ?? new ExponentialBackoff({ baseMs: DEFAULT_BACKOFF_BASE_MS, maxMs: DEFAULT_BACKOFF_MAX_MS }) + } + + /** + * Whether `error` should be retried. Override to extend or narrow the + * retryable set (e.g. to also retry transient 5xx errors). + */ + protected isRetryable(error: Error): boolean { + return error instanceof ModelThrottledError + } + + protected override computeRetryDecision(event: AfterModelCallEvent): RetryDecision { + const error = event.error + if (error === undefined || !this.isRetryable(error)) { + return { retry: false } + } + + if (event.attemptCount >= this._maxAttempts) { + logger.debug( + `attempt_count=<${event.attemptCount}> max_attempts=<${this._maxAttempts}> | max retry attempts reached` + ) + return { retry: false } + } + + if (this._firstFailureAt === undefined) { + this._firstFailureAt = Date.now() + } + + const waitMs = this._backoff.nextDelay(this._buildContext(event.attemptCount)) + + logger.debug( + `retry_delay_ms=<${waitMs}> attempt_count=<${event.attemptCount}> max_attempts=<${this._maxAttempts}> ` + + `| retryable model error, delaying before retry` + ) + + this._lastDelayMs = waitMs + return { retry: true, waitMs } + } + + protected override onFirstModelAttempt(): void { + this._lastDelayMs = undefined + this._firstFailureAt = undefined + } + + private _buildContext(attemptCount: number): BackoffContext { + const ctx: BackoffContext = { + attempt: attemptCount, + elapsedMs: this._firstFailureAt === undefined ? 0 : Date.now() - this._firstFailureAt, + } + if (this._lastDelayMs !== undefined) { + ctx.lastDelayMs = this._lastDelayMs + } + return ctx + } +} diff --git a/strands-ts/src/retry/index.ts b/strands-ts/src/retry/index.ts new file mode 100644 index 0000000000..45236c23e4 --- /dev/null +++ b/strands-ts/src/retry/index.ts @@ -0,0 +1,21 @@ +/** + * Retry utilities. + */ + +export { + type BackoffContext, + type BackoffStrategy, + type JitterKind, + type ConstantBackoffOptions, + type LinearBackoffOptions, + type ExponentialBackoffOptions, + ConstantBackoff, + LinearBackoff, + ExponentialBackoff, +} from './backoff-strategy.js' + +export { ModelRetryStrategy } from './model-retry-strategy.js' + +export { DefaultModelRetryStrategy, type DefaultModelRetryStrategyOptions } from './default-model-retry-strategy.js' + +export type { RetryStrategy, RetryDecision } from './retry-strategy.js' diff --git a/strands-ts/src/retry/model-retry-strategy.ts b/strands-ts/src/retry/model-retry-strategy.ts new file mode 100644 index 0000000000..bf7e21787d --- /dev/null +++ b/strands-ts/src/retry/model-retry-strategy.ts @@ -0,0 +1,114 @@ +/** + * Abstract base class for model-retry strategies. + */ + +import { AfterModelCallEvent } from '../hooks/events.js' +import type { Plugin } from '../plugins/plugin.js' +import type { LocalAgent } from '../types/agent.js' +import type { RetryDecision } from './retry-strategy.js' + +/** + * Abstract base class for model-retry strategies. + * + * A {@link ModelRetryStrategy} is a {@link Plugin} that retries failed model + * calls. Subclasses implement {@link computeRetryDecision} to answer *whether* to retry + * and *how long* to wait; the base class orchestrates the rest: + * + * 1. Short-circuits if another hook already set `event.retry` (no stacked delay). + * 2. Short-circuits on success events (`event.error === undefined`). + * 3. Calls {@link onFirstModelAttempt} on turn boundaries (`event.attemptCount === 1`), + * letting stateful subclasses clear per-turn state. + * 4. Invokes {@link computeRetryDecision}; on `retry: true`, sleeps for `waitMs` then + * sets `event.retry = true`. + * + * Other retry kinds (e.g. tool retries) will land as *sibling* abstract + * classes, not as additional methods on this one — different retry kinds + * have different unit-of-work boundaries and don't share a single state + * contract. + * + * Single-agent attachment: instances typically carry per-turn state, so + * sharing one instance across agents would let their calls trample each + * other. The base class throws on attempts to attach to a different agent. + */ +export abstract class ModelRetryStrategy implements Plugin { + /** + * A stable string identifier for this retry strategy. + */ + abstract readonly name: string + + private _attachedAgent: LocalAgent | undefined + + /** + * Decide whether to retry the failed model call, and how long to wait first. + * + * Called only for error events that have not already been marked for retry + * by another hook. The base class has already filtered out successes and + * short-circuited events where `event.retry` is true, so implementations + * only need to reason about `event.error`. + * + * Return `{ retry: false }` to let the error propagate. Return + * `{ retry: true, waitMs }` to retry after sleeping for `waitMs` + * milliseconds. + */ + protected abstract computeRetryDecision(event: AfterModelCallEvent): RetryDecision | Promise + + /** + * Called when `event.attemptCount === 1`, i.e. at the start of a fresh + * turn. Subclasses with per-turn state override this to clear it; the + * default is a no-op. + * + * The agent loop guarantees `attemptCount === 1` on every new turn, so + * this is a reliable turn-boundary signal. + */ + protected onFirstModelAttempt(): void {} + + /** + * @internal + * Hook callback invoked by the agent on every {@link AfterModelCallEvent}. + * Subclasses should override {@link computeRetryDecision} or + * {@link onFirstModelAttempt} instead of this method. + */ + async retryModel(event: AfterModelCallEvent): Promise { + // Fire the turn-boundary signal before any short-circuit so per-turn state + // always clears at the start of a new turn, even if a user hook already + // set event.retry on attempt 1. + if (event.attemptCount === 1) this.onFirstModelAttempt() + + if (event.retry) return + if (event.error === undefined) return + + const decision = await this.computeRetryDecision(event) + if (!decision.retry) return + + await sleep(decision.waitMs) + event.retry = true + } + + /** + * Initialize the retry strategy with the agent instance. + * + * Enforces the single-agent attachment guard and registers the + * {@link AfterModelCallEvent} hook that drives retry orchestration. + * + * Subclasses that override this method MUST call `super.initAgent(agent)` + * to preserve the attachment guard and hook registration. Additional + * hooks may be registered after the `super` call. + * + * @param agent - The agent to register hooks with + */ + initAgent(agent: LocalAgent): void { + if (this._attachedAgent !== undefined && this._attachedAgent !== agent) { + throw new Error( + `${this.constructor.name}: instance is already attached to another agent. ` + + 'Create a separate instance per agent.' + ) + } + this._attachedAgent = agent + + agent.addHook(AfterModelCallEvent, (event) => this.retryModel(event)) + } +} + +function sleep(ms: number): Promise { + return new Promise((resolve) => globalThis.setTimeout(resolve, ms)) +} diff --git a/strands-ts/src/retry/retry-strategy.ts b/strands-ts/src/retry/retry-strategy.ts new file mode 100644 index 0000000000..00df36d813 --- /dev/null +++ b/strands-ts/src/retry/retry-strategy.ts @@ -0,0 +1,45 @@ +/** + * Shared retry primitives. + */ + +import { logger } from '../logging/logger.js' +import type { ModelRetryStrategy } from './model-retry-strategy.js' + +/** + * Any retry strategy accepted by the agent. + */ +// Today this only admits model retries. Future retry kinds (e.g. tool retries) +// will be added as additional arms of this union. +export type RetryStrategy = ModelRetryStrategy + +/** + * Decision returned by a retry strategy's per-event `compute*RetryDecision` method. + * + * Discriminated union: `retry: true` carries the wait duration the framework + * will sleep for before re-invoking the failed operation. `retry: false` + * carries nothing — the error propagates to the caller. + * + * Shared across retry kinds (model retries today; tool retries and others + * later) so all strategies speak the same decision shape. + */ +export type RetryDecision = { retry: false } | { retry: true; waitMs: number } + +/** + * Emit a warning for each duplicate-type retry strategy in the list. + * + * Two strategies of the same concrete class share the same `plugin.name` + * and would otherwise collide in the plugin registry. This is a warning, + * not an error — the caller decides how to handle duplicates (e.g. keep + * the first, drop the rest). + */ +export function warnOnDuplicateRetryStrategyTypes(strategies: readonly RetryStrategy[]): void { + const seen = new Set RetryStrategy>() + for (const strategy of strategies) { + const ctor = strategy.constructor as new (...args: never[]) => RetryStrategy + if (seen.has(ctor)) { + logger.warn(`retry_strategy_type=<${ctor.name}> | multiple instances provided; only the first will be used`) + } else { + seen.add(ctor) + } + } +} diff --git a/strands-ts/src/sandbox/__tests__/posix-shell.test.node.ts b/strands-ts/src/sandbox/__tests__/posix-shell.test.node.ts new file mode 100644 index 0000000000..b58d8784f2 --- /dev/null +++ b/strands-ts/src/sandbox/__tests__/posix-shell.test.node.ts @@ -0,0 +1,292 @@ +import { describe, it, expect, beforeEach, afterEach } from 'vitest' +import fs from 'fs' +import { TestSandbox } from '../../__fixtures__/test-sandbox.node.js' +import { streamProcess } from '../stream-process.js' +import type { ExecutionResult, StreamChunk } from '../types.js' + +const TEST_DIR = '/tmp/strands-test-shell-sandbox' + +describe.skipIf(process.platform === 'win32')('PosixShellSandbox', () => { + let sandbox: TestSandbox + + beforeEach(() => { + fs.rmSync(TEST_DIR, { recursive: true, force: true }) + fs.mkdirSync(TEST_DIR, { recursive: true }) + sandbox = new TestSandbox(TEST_DIR) + }) + + afterEach(() => { + fs.rmSync(TEST_DIR, { recursive: true, force: true }) + }) + + describe('execute (via shell commands)', () => { + it('runs a command', async () => { + const result = await sandbox.execute('echo hello') + expect(result.exitCode).toBe(0) + expect(result.stdout).toBe('hello\n') + }) + + it('runs in workingDir', async () => { + const result = await sandbox.execute('pwd') + expect(result.stdout.trim()).toContain('strands-test-shell-sandbox') + }) + + it('respects cwd option', async () => { + const result = await sandbox.execute('pwd', { cwd: '/tmp' }) + expect(result.stdout.trim()).toMatch(/\/tmp$/) + }) + }) + + describe('executeCode (via shell quoting)', () => { + it('runs python code through shell', async () => { + const result = await sandbox.executeCode('print(2 + 2)', 'python3') + expect(result.exitCode).toBe(0) + expect(result.stdout).toBe('4\n') + }) + + it('handles code with special characters', async () => { + const result = await sandbox.executeCode('print(\'hello "world"\')', 'python3') + expect(result.stdout).toBe('hello "world"\n') + }) + + it('handles code with single quotes', async () => { + const result = await sandbox.executeCode('print("it\'s working")', 'python3') + expect(result.stdout).toBe("it's working\n") + }) + }) + + describe('language validation', () => { + it('rejects path traversal', async () => { + await expect(sandbox.executeCode('x', '../../../bin/sh')).rejects.toThrow('invalid characters') + }) + + it('rejects shell metacharacters', async () => { + await expect(sandbox.executeCode('x', 'python;rm -rf /')).rejects.toThrow('invalid characters') + }) + + it('rejects spaces', async () => { + await expect(sandbox.executeCode('x', 'python -c')).rejects.toThrow('invalid characters') + }) + + it('allows valid interpreters', async () => { + const result = await sandbox.executeCode('print("safe")', 'python3') + expect(result.exitCode).toBe(0) + }) + + it('allows dots and hyphens', async () => { + const result = await sandbox.executeCode('x', 'fake-lang.99') + expect(result.exitCode).toBe(127) + }) + }) + + describe('read/write (via base64 encoding over shell)', () => { + it('text file roundtrip', async () => { + await sandbox.writeText('test.txt', 'hello shell') + const text = await sandbox.readText('test.txt') + expect(text).toBe('hello shell') + }) + + it('binary file roundtrip', async () => { + const bytes = new Uint8Array([0, 1, 2, 127, 128, 254, 255]) + await sandbox.writeFile('binary.bin', bytes) + const read = await sandbox.readFile('binary.bin') + expect(Array.from(read)).toStrictEqual(Array.from(bytes)) + }) + + it('all 256 byte values roundtrip', async () => { + const bytes = new Uint8Array(256) + for (let i = 0; i < 256; i++) bytes[i] = i + await sandbox.writeFile('all-bytes.bin', bytes) + const read = await sandbox.readFile('all-bytes.bin') + expect(Array.from(read)).toStrictEqual(Array.from(bytes)) + }) + + it('creates parent directories', async () => { + await sandbox.writeText('deep/nested/file.txt', 'deep') + const text = await sandbox.readText('deep/nested/file.txt') + expect(text).toBe('deep') + }) + + it('handles unicode content', async () => { + const content = '日本語 🚀 émojis' + await sandbox.writeText('unicode.txt', content) + const text = await sandbox.readText('unicode.txt') + expect(text).toBe(content) + }) + + it('handles shell metacharacters in content', async () => { + const content = '$(rm -rf /) `whoami` && || $HOME' + await sandbox.writeText('meta.txt', content) + const text = await sandbox.readText('meta.txt') + expect(text).toBe(content) + }) + + it('throws on nonexistent file', async () => { + await expect(sandbox.readFile('nope.txt')).rejects.toThrow() + }) + }) + + describe('remove', () => { + it('removes a file', async () => { + await sandbox.writeText('delete-me.txt', 'bye') + await sandbox.removeFile('delete-me.txt') + await expect(sandbox.readFile('delete-me.txt')).rejects.toThrow() + }) + + it('throws on nonexistent file', async () => { + await expect(sandbox.removeFile('nope.txt')).rejects.toThrow() + }) + }) + + describe('list (via ls -1ap parsing)', () => { + it('lists directory contents', async () => { + await sandbox.writeText('a.txt', 'a') + await sandbox.writeText('b.txt', 'b') + const files = await sandbox.listFiles('.') + const names = files.map((f) => f.name) + expect(names).toContain('a.txt') + expect(names).toContain('b.txt') + }) + + it('identifies directories', async () => { + await sandbox.execute('mkdir -p subdir') + const files = await sandbox.listFiles('.') + const subdir = files.find((f) => f.name === 'subdir') + expect(subdir?.isDir).toBe(true) + }) + + it('excludes . and .. entries', async () => { + await sandbox.writeText('file.txt', '') + const files = await sandbox.listFiles('.') + const names = files.map((f) => f.name) + expect(names).not.toContain('.') + expect(names).not.toContain('..') + }) + + it('throws on nonexistent directory', async () => { + await expect(sandbox.listFiles('/tmp/nonexistent-dir-xyz')).rejects.toThrow() + }) + + it('throws when path is a file, not a directory', async () => { + await sandbox.writeText('not-a-dir.txt', 'hello') + await expect(sandbox.listFiles('not-a-dir.txt')).rejects.toThrow() + }) + }) + + describe('shellQuote', () => { + it('handles paths with spaces', async () => { + await sandbox.execute('mkdir -p "with spaces"') + await sandbox.writeText('with spaces/file.txt', 'spaced') + const text = await sandbox.readText('with spaces/file.txt') + expect(text).toBe('spaced') + }) + + it('handles paths with single quotes', async () => { + await sandbox.execute('mkdir -p "it\'s"') + await sandbox.writeText("it's/file.txt", 'quoted') + const text = await sandbox.readText("it's/file.txt") + expect(text).toBe('quoted') + }) + }) + + describe('timeout', () => { + it('kills process on timeout', async () => { + const start = Date.now() + await expect(sandbox.execute('sleep 60', { timeout: 0.2 })).rejects.toThrow('timed out') + const elapsed = Date.now() - start + expect(elapsed).toBeLessThan(2000) + }) + + it('does not timeout fast commands', async () => { + const result = await sandbox.execute('echo fast', { timeout: 5 }) + expect(result.exitCode).toBe(0) + expect(result.stdout).toBe('fast\n') + }) + }) + + describe('abort signal', () => { + it('kills process when signal is aborted', async () => { + const controller = new AbortController() + const promise = sandbox.execute('sleep 60', { signal: controller.signal }) + setTimeout(() => controller.abort(), 100) + await expect(promise).rejects.toThrow('aborted') + }) + + it('rejects immediately if signal is already aborted', async () => { + const controller = new AbortController() + controller.abort() + await expect(sandbox.execute('sleep 60', { signal: controller.signal })).rejects.toThrow('aborted') + }) + }) + + describe('concurrent execution', () => { + it('handles multiple concurrent commands', async () => { + const results = await Promise.all([ + sandbox.execute('echo one'), + sandbox.execute('echo two'), + sandbox.execute('echo three'), + ]) + expect(results.map((r) => r.stdout.trim()).sort()).toStrictEqual(['one', 'three', 'two']) + }) + + it('handles concurrent file writes to different files', async () => { + await Promise.all([ + sandbox.writeText('a.txt', 'aaa'), + sandbox.writeText('b.txt', 'bbb'), + sandbox.writeText('c.txt', 'ccc'), + ]) + const [a, b, c] = await Promise.all([ + sandbox.readText('a.txt'), + sandbox.readText('b.txt'), + sandbox.readText('c.txt'), + ]) + expect(a).toBe('aaa') + expect(b).toBe('bbb') + expect(c).toBe('ccc') + }) + }) + + describe('streaming', () => { + it('yields StreamChunks then ExecutionResult', async () => { + const chunks: Array<{ type: string }> = [] + for await (const chunk of sandbox.executeStreaming('echo hello')) { + chunks.push(chunk) + } + const streamChunks = chunks.filter((c) => c.type === 'streamChunk') + const results = chunks.filter((c) => c.type === 'executionResult') + expect(streamChunks.length).toBeGreaterThan(0) + expect(results).toHaveLength(1) + }) + }) + + describe('streamProcess edge cases', () => { + it('returns exit code 127 when command is not found', async () => { + const result = await sandbox.execute('nonexistent_binary_xyz_12345') + expect(result.exitCode).toBe(127) + expect(result.stderr).toContain('not found') + }) + + it('maps signal termination to 128 + signal number', async () => { + // sh -c 'kill -9 $$' sends SIGKILL to itself → exit code 128 + 9 = 137 + const result = await sandbox.execute("sh -c 'kill -9 $$'") + expect(result.exitCode).toBe(137) + }) + + it('returns enoentMessage when spawned binary does not exist', async () => { + const chunks: (StreamChunk | ExecutionResult)[] = [] + for await (const chunk of streamProcess('nonexistent_binary_xyz_12345', [], { + enoentMessage: 'binary not found', + })) { + chunks.push(chunk) + } + const result = chunks.find((c): c is ExecutionResult => c.type === 'executionResult') + expect(result).toStrictEqual({ + type: 'executionResult', + exitCode: 127, + stdout: '', + stderr: 'binary not found', + outputFiles: [], + }) + }) + }) +}) diff --git a/strands-ts/src/sandbox/base.ts b/strands-ts/src/sandbox/base.ts new file mode 100644 index 0000000000..1b5a8806d4 --- /dev/null +++ b/strands-ts/src/sandbox/base.ts @@ -0,0 +1,173 @@ +/** + * Base sandbox interface. + * + * Defines the abstract {@link Sandbox} class that all sandbox implementations + * must extend. The class provides six abstract operations (command execution, + * code execution, and file I/O) and convenience wrappers for common patterns. + */ + +import type { ExecutionResult, FileInfo, StreamChunk } from './types.js' + +/** + * Options for command and code execution. + */ +export interface ExecuteOptions { + /** Maximum execution time in seconds. `undefined` means no timeout. */ + timeout?: number | undefined + /** Working directory for execution. `undefined` means use the sandbox default. */ + cwd?: string | undefined + /** Abort signal to cancel execution. The process is killed when the signal fires. */ + signal?: AbortSignal | undefined +} + +/** + * Abstract execution environment. + * + * A Sandbox provides the runtime context where tools execute code, + * run commands, and interact with a filesystem. Multiple tools share + * the same Sandbox instance, giving them a common working directory + * and filesystem. + * + * Streaming methods (`executeStreaming`, `executeCodeStreaming`) are the abstract primitives. + * Non-streaming convenience methods (`execute`, `executeCode`) consume + * the stream and return the final result. + */ +export abstract class Sandbox { + /** + * Execute a shell command, streaming output. + * + * Yields {@link StreamChunk} objects for stdout and stderr as output + * arrives. The final yield is an {@link ExecutionResult} with the + * exit code and complete output. + * + * @param command - The shell command to execute. + * @param options - Execution options (timeout, cwd). + * @returns Async iterable yielding StreamChunks followed by a final ExecutionResult. + */ + abstract executeStreaming(command: string, options?: ExecuteOptions): AsyncIterable + + /** + * Execute source code via a language interpreter, streaming output. + * + * @param code - The source code to execute. + * @param language - The interpreter to use (e.g., `"python3"`, `"node"`). + * @param options - Execution options (timeout, cwd). + * @returns Async iterable yielding StreamChunks followed by a final ExecutionResult. + */ + abstract executeCodeStreaming( + code: string, + language: string, + options?: ExecuteOptions + ): AsyncIterable + + /** + * Read a file from the sandbox filesystem as raw bytes. + * + * Returns `Uint8Array` to support both text and binary files. + * Use {@link readText} for a convenience wrapper that decodes to a string. + * + * @param path - Path to the file to read. + * @returns The file contents as raw bytes. + * @throws Error if the file does not exist. + */ + abstract readFile(path: string): Promise + + /** + * Write raw bytes to a file in the sandbox filesystem. + * + * Implementations should create parent directories if they do not exist. + * Use {@link writeText} for a convenience wrapper that encodes a string. + * + * @param path - Path to the file to write. + * @param content - The content to write. + */ + abstract writeFile(path: string, content: Uint8Array): Promise + + /** + * Remove a file from the sandbox filesystem. + * + * @param path - Path to the file to remove. + * @throws Error if the file does not exist. + */ + abstract removeFile(path: string): Promise + + /** + * List files in a sandbox directory. + * + * Returns {@link FileInfo} entries with name, isDir, and size metadata. + * Fields `isDir` and `size` may be `undefined` if the backend cannot + * determine them. + * + * @param path - Path to the directory to list. + * @returns Array of FileInfo entries for the directory contents. + * @throws Error if the directory does not exist. + */ + abstract listFiles(path: string): Promise + + // ---- Non-streaming convenience methods ---- + + /** + * Execute a shell command and return the result. + * + * Consumes {@link executeStreaming} and returns the final {@link ExecutionResult}. + * Use `executeStreaming` when you need to process output as it arrives. + * + * @param command - The shell command to execute. + * @param options - Execution options (timeout, cwd). + * @returns The execution result with exit code and output. + */ + async execute(command: string, options?: ExecuteOptions): Promise { + for await (const chunk of this.executeStreaming(command, options)) { + if (chunk.type === 'executionResult') { + return chunk + } + } + throw new Error('executeStreaming() did not yield an ExecutionResult') + } + + /** + * Execute source code and return the result. + * + * Consumes {@link executeCodeStreaming} and returns the final {@link ExecutionResult}. + * Use `executeCodeStreaming` when you need to process output as it arrives. + * + * @param code - The source code to execute. + * @param language - The interpreter to use. + * @param options - Execution options (timeout, cwd). + * @returns The execution result with exit code and output. + */ + async executeCode(code: string, language: string, options?: ExecuteOptions): Promise { + for await (const chunk of this.executeCodeStreaming(code, language, options)) { + if (chunk.type === 'executionResult') { + return chunk + } + } + throw new Error('executeCodeStreaming() did not yield an ExecutionResult') + } + + /** + * Read a text file from the sandbox filesystem. + * + * Convenience wrapper over {@link readFile} that decodes bytes as UTF-8. + * For other encodings, call `readFile` and decode manually. + * + * @param path - Path to the file to read. + * @returns The file contents decoded as a UTF-8 string. + */ + async readText(path: string): Promise { + return new TextDecoder().decode(await this.readFile(path)) + } + + /** + * Write a text file to the sandbox filesystem. + * + * Convenience wrapper over {@link writeFile} that encodes a string as UTF-8. + * For other encodings, encode manually and call `writeFile`. + * + * @param path - Path to the file to write. + * @param content - The text content to write. + */ + async writeText(path: string, content: string): Promise { + await this.writeFile(path, new TextEncoder().encode(content)) + } +} diff --git a/strands-ts/src/sandbox/constants.ts b/strands-ts/src/sandbox/constants.ts new file mode 100644 index 0000000000..43c8b3641f --- /dev/null +++ b/strands-ts/src/sandbox/constants.ts @@ -0,0 +1,6 @@ +/** + * Regex pattern for validating language/interpreter names. + * Allows alphanumeric characters, dots, hyphens, and underscores. + * Rejects path separators, spaces, and shell metacharacters to prevent injection. + */ +export const LANGUAGE_PATTERN = /^[a-zA-Z0-9._-]+$/ diff --git a/strands-ts/src/sandbox/posix-shell.ts b/strands-ts/src/sandbox/posix-shell.ts new file mode 100644 index 0000000000..c0399b685f --- /dev/null +++ b/strands-ts/src/sandbox/posix-shell.ts @@ -0,0 +1,90 @@ +/** + * Shell sandbox with default implementations for file and code operations. + * + * Subclasses only need to implement {@link PosixShellSandbox.executeStreaming} — + * all other operations are implemented by running shell commands through it. + * Use this for remote environments where only shell access is available + * (Docker containers, SSH connections, cloud runtimes). + */ + +import { Sandbox } from './base.js' +import type { ExecuteOptions } from './base.js' +import { LANGUAGE_PATTERN } from './constants.js' +import type { ExecutionResult, FileInfo, StreamChunk } from './types.js' +import { shellQuote } from '../utils/shell-quote.js' + +/** + * Abstract sandbox that provides shell-based defaults for file and code operations. + * Assumes a POSIX-compatible shell (sh/bash) on the target. + * + * Subclasses only need to implement {@link executeStreaming}. The remaining + * operations — `executeCodeStreaming`, `readFile`, `writeFile`, `removeFile`, + * and `listFiles` — are implemented via shell commands piped through + * `executeStreaming`. + * + * Subclasses may override any method with a native implementation for + * better performance or to handle edge cases (e.g., binary-safe file + * transfer via Docker stdin pipes, or native API calls for cloud backends). + */ +export abstract class PosixShellSandbox extends Sandbox { + async *executeCodeStreaming( + code: string, + language: string, + options?: ExecuteOptions + ): AsyncGenerator { + if (!LANGUAGE_PATTERN.test(language)) { + throw new Error(`language parameter contains invalid characters: ${language}`) + } + const encoded = btoa(Array.from(new TextEncoder().encode(code), (b) => String.fromCharCode(b)).join('')) + const eof = `STRANDS_EOF_${crypto.randomUUID().slice(0, 16)}` + yield* this.executeStreaming(`base64 -d << '${eof}' | ${language}\n${encoded}\n${eof}`, options) + } + + async readFile(path: string): Promise { + const result = await this.execute(`base64 < ${shellQuote(path)}`) + if (result.exitCode !== 0) { + throw new Error(result.stderr || `Failed to read file: ${path}`) + } + return Uint8Array.from(atob(result.stdout.replace(/\s/g, '')), (c) => c.charCodeAt(0)) + } + + async writeFile(path: string, content: Uint8Array): Promise { + const encoded = btoa(Array.from(content, (b) => String.fromCharCode(b)).join('')) + const quoted = shellQuote(path) + const eof = `STRANDS_EOF_${crypto.randomUUID().slice(0, 16)}` + const cmd = `mkdir -p "$(dirname ${quoted})" && base64 -d << '${eof}' > ${quoted}\n${encoded}\n${eof}` + const result = await this.execute(cmd) + if (result.exitCode !== 0) { + throw new Error(result.stderr || `Failed to write file: ${path}`) + } + } + + async removeFile(path: string): Promise { + const result = await this.execute(`rm ${shellQuote(path)}`) + if (result.exitCode !== 0) { + throw new Error(result.stderr || `Failed to remove file: ${path}`) + } + } + + async listFiles(path: string): Promise { + const quoted = shellQuote(path) + const result = await this.execute(`test -d ${quoted} || exit 1; env QUOTING_STYLE=literal ls -1ap ${quoted}`) + if (result.exitCode !== 0) { + throw new Error(result.stderr || `Failed to list directory: ${path}`) + } + + const entries: FileInfo[] = [] + for (const raw of result.stdout.split('\n')) { + const line = raw.replace(/\r$/, '') + if (!line || line === './' || line === '../') { + continue + } + const isDir = line.endsWith('/') + const name = isDir ? line.slice(0, -1) : line + if (name) { + entries.push({ name, isDir }) + } + } + return entries + } +} diff --git a/strands-ts/src/sandbox/stream-process.ts b/strands-ts/src/sandbox/stream-process.ts new file mode 100644 index 0000000000..21d7784474 --- /dev/null +++ b/strands-ts/src/sandbox/stream-process.ts @@ -0,0 +1,181 @@ +/** + * Spawn a process and stream its stdout/stderr as an async generator. + */ + +import { spawn } from 'child_process' +import type { ExecutionResult, StreamChunk } from './types.js' + +const SIGNAL_CODES: Record = { + SIGHUP: 1, + SIGINT: 2, + SIGQUIT: 3, + SIGABRT: 6, + SIGKILL: 9, + SIGSEGV: 11, + SIGPIPE: 13, + SIGTERM: 15, +} + +/** + * Options for {@link streamProcess}. + */ +export interface StreamProcessOptions { + /** Maximum execution time in seconds. */ + timeout?: number | undefined + /** Abort signal to cancel execution. */ + signal?: AbortSignal | undefined + /** Custom error message when the spawned binary is not found (ENOENT). */ + enoentMessage?: string | undefined +} + +/** + * Spawn a command and stream its stdout/stderr, yielding the final result. + * + * Bridges Node.js event emitters to an async generator. Chunks are + * yielded incrementally as the process produces output. The final + * yield is an ExecutionResult with the exit code and complete output. + * + * All listeners are attached synchronously before any await to prevent + * missed events from fast-completing processes. + * + * @param command - The binary to spawn. + * @param args - Arguments to pass to the binary. + * @param options - Timeout, abort signal, and ENOENT handling options. + * @returns An async generator yielding StreamChunks followed by a final ExecutionResult. + */ +export async function* streamProcess( + command: string, + args: string[], + options?: StreamProcessOptions +): AsyncGenerator { + const proc = spawn(command, args) + const chunks: StreamChunk[] = [] + let stdout = '' + let stderr = '' + let done = false + let terminating = false + let exitCode = 0 + let error: Error | undefined + let enoent = false + let resolveWait: (() => void) | undefined + let timeoutHandle: ReturnType | undefined + let killTimer: ReturnType | undefined + + const wake = (): void => { + if (resolveWait) { + resolveWait() + resolveWait = undefined + } + } + + const terminate = (reason: Error): void => { + if (terminating || done) return + terminating = true + error = reason + proc.kill('SIGTERM') + wake() + killTimer = setTimeout(() => { + if (!done) proc.kill('SIGKILL') + }, 1000) + } + + proc.stdout?.on('data', (data) => { + const text = String(data) + stdout += text + chunks.push({ type: 'streamChunk', data: text, streamType: 'stdout' }) + wake() + }) + + proc.stderr?.on('data', (data) => { + const text = String(data) + stderr += text + chunks.push({ type: 'streamChunk', data: text, streamType: 'stderr' }) + wake() + }) + + proc.on('close', (code, signal) => { + if (!done) { + if (code !== null) { + exitCode = code + } else if (signal) { + exitCode = 128 + (SIGNAL_CODES[signal] ?? 1) + } else { + exitCode = 1 + } + done = true + wake() + } + }) + + proc.on('error', (err) => { + if (!done) { + if (options?.enoentMessage && 'code' in err && err.code === 'ENOENT') { + enoent = true + } else { + error = err + } + done = true + wake() + } + }) + + const onAbort = (): void => terminate(new Error('Execution aborted')) + + if (options?.signal) { + if (options.signal.aborted) { + onAbort() + } else { + options.signal.addEventListener('abort', onAbort, { once: true }) + } + } + + if (options?.timeout !== undefined) { + timeoutHandle = setTimeout(() => { + terminate(new Error(`Execution timed out after ${options.timeout} seconds`)) + }, options.timeout * 1000) + } + + try { + while (true) { + if (chunks.length > 0) { + const batch = chunks.splice(0, chunks.length) + for (const chunk of batch) { + yield chunk + } + } + + if (done || terminating) break + + await new Promise((resolve) => { + resolveWait = resolve + setTimeout(resolve, 50) + }) + } + + if (enoent) { + yield { + type: 'executionResult', + exitCode: 127, + stdout: '', + stderr: options!.enoentMessage!, + outputFiles: [], + } satisfies ExecutionResult + return + } + + if (error) throw error + + yield { + type: 'executionResult', + exitCode, + stdout, + stderr, + outputFiles: [], + } satisfies ExecutionResult + } finally { + if (timeoutHandle !== undefined) clearTimeout(timeoutHandle) + if (killTimer !== undefined) clearTimeout(killTimer) + if (options?.signal) options.signal.removeEventListener('abort', onAbort) + if (!done) proc.kill() + } +} diff --git a/strands-ts/src/sandbox/types.ts b/strands-ts/src/sandbox/types.ts new file mode 100644 index 0000000000..94c116e7ba --- /dev/null +++ b/strands-ts/src/sandbox/types.ts @@ -0,0 +1,61 @@ +/** + * Data types for the sandbox abstraction. + * + * These types represent the inputs and outputs of sandbox operations — + * execution results, file metadata, and streaming chunks. + */ + +/** + * Type of a streaming output chunk — distinguishes stdout from stderr. + */ +export type StreamType = 'stdout' | 'stderr' + +/** + * A typed chunk of streaming output from command or code execution. + * + * Allows consumers to distinguish stdout from stderr during streaming, + * enabling richer UIs and more precise output handling. + */ +export interface StreamChunk { + readonly type: 'streamChunk' + readonly data: string + readonly streamType: StreamType +} + +/** + * Metadata about a file or directory in a sandbox. + * + * Provides minimal structured information that lets tools distinguish + * files from directories and report sizes. `isDir` and `size` are + * `undefined` when the backend cannot determine them accurately. + */ +export interface FileInfo { + readonly name: string + readonly isDir?: boolean + readonly size?: number +} + +/** + * A file produced as output by code execution. + * + * Used to carry binary artifacts (images, charts, PDFs, compiled files) + * from sandbox execution back to the agent. Shell-based sandboxes + * typically return an empty array. Jupyter-backed or API-backed + * sandboxes can populate this with generated artifacts. + */ +export interface OutputFile { + readonly name: string + readonly content: Uint8Array + readonly mimeType: string +} + +/** + * Result of command or code execution in a sandbox. + */ +export interface ExecutionResult { + readonly type: 'executionResult' + readonly exitCode: number + readonly stdout: string + readonly stderr: string + readonly outputFiles: OutputFile[] +} diff --git a/strands-ts/src/session/__tests__/file-storage.test.node.ts b/strands-ts/src/session/__tests__/file-storage.test.node.ts new file mode 100644 index 0000000000..180a2da86f --- /dev/null +++ b/strands-ts/src/session/__tests__/file-storage.test.node.ts @@ -0,0 +1,399 @@ +import { describe, expect, it, beforeEach, afterEach, vi } from 'vitest' +import { promises as fs } from 'fs' +import { join } from 'path' +import { tmpdir } from 'os' +import { FileStorage } from '../file-storage.js' +import { SessionError } from '../../errors.js' +import { + createTestSnapshot, + createTestManifest, + createTestScope, + createTestSnapshots, +} from '../../__fixtures__/mock-storage-provider.js' +import type { SnapshotLocation } from '../storage.js' + +const SCOPE_ID = 'test-agent' + +describe('FileStorage', () => { + let storage: FileStorage + let testDir: string + + beforeEach(async () => { + testDir = join(tmpdir(), `file-storage-test-${Date.now()}-${Math.random().toString(36).slice(2)}`) + await fs.mkdir(testDir, { recursive: true }) + storage = new FileStorage(testDir) + }) + + afterEach(async () => { + try { + await fs.rm(testDir, { recursive: true, force: true }) + } catch { + // Ignore cleanup errors + } + }) + + describe('saveSnapshot', () => { + describe('FileSnapshotStorage_When_saveSnapshot_Then_CreatesFiles', () => { + it('saves snapshot to history file', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + + await storage.saveSnapshot({ location, snapshotId: '1', isLatest: false, snapshot }) + + const historyPath = join( + testDir, + location.sessionId, + 'scopes', + 'agent', + SCOPE_ID, + 'snapshots', + 'immutable_history', + 'snapshot_1.json' + ) + const content = await fs.readFile(historyPath, 'utf8') + expect(JSON.parse(content)).toEqual(snapshot) + }) + + it('saves snapshot as latest when isLatest is true', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + + await storage.saveSnapshot({ location, snapshotId: '1', isLatest: true, snapshot }) + + const latestPath = join( + testDir, + location.sessionId, + 'scopes', + 'agent', + SCOPE_ID, + 'snapshots', + 'snapshot_latest.json' + ) + const content = await fs.readFile(latestPath, 'utf8') + expect(JSON.parse(content)).toEqual(snapshot) + }) + + it('creates directories recursively', async () => { + const location: SnapshotLocation = { + sessionId: 'new-session', + scope: createTestScope('agent'), + scopeId: 'new-agent', + } + const snapshot = createTestSnapshot() + + await storage.saveSnapshot({ location, snapshotId: '1', isLatest: true, snapshot }) + + const expectedDir = join(testDir, location.sessionId, 'scopes', 'agent', location.scopeId, 'snapshots') + const stats = await fs.stat(expectedDir) + expect(stats.isDirectory()).toBe(true) + }) + }) + + describe('FileSnapshotStorage_When_saveSnapshotFails_Then_ThrowsSessionError', () => { + it('throws SessionError when write fails', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + + vi.spyOn(fs, 'writeFile').mockRejectedValueOnce(new Error('Write failed')) + + await expect(storage.saveSnapshot({ location, snapshotId: '1', isLatest: false, snapshot })).rejects.toThrow( + SessionError + ) + }) + }) + + describe('FileSnapshotStorage_When_MultiAgentScope_Then_SavesCorrectly', () => { + it('saves multi-agent snapshot to correct path', async () => { + const location: SnapshotLocation = { + sessionId: 'multi-session', + scope: createTestScope('multiAgent'), + scopeId: 'graph-1', + } + const snapshot = createTestSnapshot({ scope: 'multiAgent' }) + + await storage.saveSnapshot({ location, snapshotId: '1', isLatest: true, snapshot }) + + const expectedPath = join( + testDir, + location.sessionId, + 'scopes', + 'multiAgent', + location.scopeId, + 'snapshots', + 'snapshot_latest.json' + ) + const content = await fs.readFile(expectedPath, 'utf8') + expect(JSON.parse(content)).toEqual(snapshot) + }) + }) + }) + + describe('loadSnapshot', () => { + describe('FileSnapshotStorage_When_LoadLatestSnapshot_Then_ReturnsSnapshot', () => { + it('loads latest snapshot when snapshotId is undefined', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + await storage.saveSnapshot({ location, snapshotId: '1', isLatest: true, snapshot }) + + const result = await storage.loadSnapshot({ location }) + + expect(result).toEqual(snapshot) + }) + }) + + describe('FileSnapshotStorage_When_LoadSpecificSnapshot_Then_ReturnsSnapshot', () => { + it('loads specific snapshot by ID', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + await storage.saveSnapshot({ location, snapshotId: '5', isLatest: false, snapshot }) + + const result = await storage.loadSnapshot({ location, snapshotId: '5' }) + + expect(result).toEqual(snapshot) + }) + }) + + describe('FileSnapshotStorage_When_SnapshotNotFound_Then_ReturnsNull', () => { + it('returns null when snapshot file does not exist', async () => { + const result = await storage.loadSnapshot({ + location: { sessionId: 'nonexistent', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toBeNull() + }) + }) + + describe('FileSnapshotStorage_When_InvalidJSON_Then_ThrowsSessionError', () => { + it('throws SessionError when JSON is invalid', async () => { + const sessionId = 'test-session' + const filePath = join(testDir, sessionId, 'scopes', 'agent', SCOPE_ID, 'snapshots', 'snapshot_latest.json') + + await fs.mkdir(join(testDir, sessionId, 'scopes', 'agent', SCOPE_ID, 'snapshots'), { recursive: true }) + await fs.writeFile(filePath, 'invalid json', 'utf8') + + await expect( + storage.loadSnapshot({ location: { sessionId, scope: 'agent', scopeId: SCOPE_ID } }) + ).rejects.toThrow(SessionError) + }) + }) + + describe('FileSnapshotStorage_When_ReadError_Then_ThrowsSessionError', () => { + it('throws SessionError when file read fails', async () => { + vi.spyOn(fs, 'readFile').mockRejectedValueOnce(new Error('Permission denied')) + await expect( + storage.loadSnapshot({ location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } }) + ).rejects.toThrow(SessionError) + }) + }) + }) + + describe('listSnapshotIds', () => { + describe('FileSnapshotStorage_When_listSnapshots_Then_ReturnsOrderedIds', () => { + it('returns sorted snapshot IDs', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshots = createTestSnapshots(3) + const ids = [ + '019c9bf1-14e5-7eef-96fb-cc07ae54210f', + '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd', + '019c9bf1-24bb-7eef-96fb-ddcc943cd859', + ] + + await storage.saveSnapshot({ location, snapshotId: ids[2]!, isLatest: false, snapshot: snapshots[2]! }) + await storage.saveSnapshot({ location, snapshotId: ids[0]!, isLatest: false, snapshot: snapshots[0]! }) + await storage.saveSnapshot({ location, snapshotId: ids[1]!, isLatest: false, snapshot: snapshots[1]! }) + + const result = await storage.listSnapshotIds({ location }) + + expect(result).toEqual(ids) + }) + + it('ignores non-snapshot files', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + const id = '019c9bf1-14e5-7eef-96fb-cc07ae54210f' + await storage.saveSnapshot({ location, snapshotId: id, isLatest: false, snapshot }) + + const historyDir = join( + testDir, + location.sessionId, + 'scopes', + 'agent', + SCOPE_ID, + 'snapshots', + 'immutable_history' + ) + await fs.writeFile(join(historyDir, 'other-file.txt'), 'not a snapshot', 'utf8') + + const result = await storage.listSnapshotIds({ location }) + expect(result).toEqual([id]) + }) + + it('filters by startAfter for pagination', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshots = createTestSnapshots(3) + const ids = [ + '019c9bf1-14e5-7eef-96fb-cc07ae54210f', + '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd', + '019c9bf1-24bb-7eef-96fb-ddcc943cd859', + ] + for (let i = 0; i < ids.length; i++) { + await storage.saveSnapshot({ location, snapshotId: ids[i]!, isLatest: false, snapshot: snapshots[i]! }) + } + + const result = await storage.listSnapshotIds({ location, startAfter: ids[0]! }) + + expect(result).toEqual([ids[1], ids[2]]) + }) + + it('limits results when limit is provided', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshots = createTestSnapshots(3) + const ids = [ + '019c9bf1-14e5-7eef-96fb-cc07ae54210f', + '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd', + '019c9bf1-24bb-7eef-96fb-ddcc943cd859', + ] + for (let i = 0; i < ids.length; i++) { + await storage.saveSnapshot({ location, snapshotId: ids[i]!, isLatest: false, snapshot: snapshots[i]! }) + } + + const result = await storage.listSnapshotIds({ location, limit: 2 }) + + expect(result).toEqual([ids[0], ids[1]]) + }) + + it('combines startAfter and limit', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshots = createTestSnapshots(3) + const ids = [ + '019c9bf1-14e5-7eef-96fb-cc07ae54210f', + '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd', + '019c9bf1-24bb-7eef-96fb-ddcc943cd859', + ] + for (let i = 0; i < ids.length; i++) { + await storage.saveSnapshot({ location, snapshotId: ids[i]!, isLatest: false, snapshot: snapshots[i]! }) + } + + const result = await storage.listSnapshotIds({ location, startAfter: ids[0]!, limit: 1 }) + + expect(result).toEqual([ids[1]]) + }) + }) + + describe('FileSnapshotStorage_When_DirectoryNotFound_Then_ReturnsEmptyArray', () => { + it('returns empty array when directory does not exist', async () => { + const result = await storage.listSnapshotIds({ + location: { sessionId: 'nonexistent', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toEqual([]) + }) + }) + + describe('FileSnapshotStorage_When_ReadDirFails_Then_ThrowsSessionError', () => { + it('throws SessionError when readdir fails with non-ENOENT error', async () => { + vi.spyOn(fs, 'readdir').mockRejectedValueOnce(new Error('Permission denied')) + await expect( + storage.listSnapshotIds({ location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } }) + ).rejects.toThrow(SessionError) + }) + }) + }) + + describe('deleteSession', () => { + describe('FileSnapshotStorage_When_DeleteSession_Then_RemovesDirectory', () => { + it('removes the entire session directory', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + await storage.saveSnapshot({ location, snapshotId: '1', isLatest: true, snapshot: createTestSnapshot() }) + + await storage.deleteSession({ sessionId: 'test-session' }) + + await expect(fs.stat(join(testDir, 'test-session'))).rejects.toMatchObject({ code: 'ENOENT' }) + }) + + it('no-ops when session directory does not exist', async () => { + await expect(storage.deleteSession({ sessionId: 'nonexistent-session' })).resolves.toBeUndefined() + }) + }) + + describe('FileSnapshotStorage_When_DeleteSessionFails_Then_ThrowsSessionError', () => { + it('throws SessionError when rm fails', async () => { + vi.spyOn(fs, 'rm').mockRejectedValueOnce(new Error('Permission denied')) + await expect(storage.deleteSession({ sessionId: 'test-session' })).rejects.toThrow(SessionError) + }) + }) + }) + + describe('saveManifest', () => { + describe('FileSnapshotStorage_When_SaveManifest_Then_CreatesFile', () => { + it('saves manifest to correct path', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } + const manifest = createTestManifest() + + await storage.saveManifest({ location, manifest }) + + const manifestPath = join( + testDir, + location.sessionId, + 'scopes', + 'agent', + SCOPE_ID, + 'snapshots', + 'manifest.json' + ) + const content = await fs.readFile(manifestPath, 'utf8') + expect(JSON.parse(content)).toEqual(manifest) + }) + }) + + describe('FileSnapshotStorage_When_SaveManifestFails_Then_ThrowsSessionError', () => { + it('throws SessionError when write fails', async () => { + vi.spyOn(fs, 'writeFile').mockRejectedValueOnce(new Error('Write failed')) + await expect( + storage.saveManifest({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + manifest: createTestManifest(), + }) + ).rejects.toThrow(SessionError) + }) + }) + }) + + describe('loadManifest', () => { + describe('FileSnapshotStorage_When_LoadManifest_Then_ReturnsManifest', () => { + it('loads manifest from file', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } + const manifest = createTestManifest() + await storage.saveManifest({ location, manifest }) + + const result = await storage.loadManifest({ location }) + + expect(result).toEqual(manifest) + }) + }) + + describe('FileSnapshotStorage_When_ManifestNotFound_Then_ReturnsDefault', () => { + it('returns default manifest when manifest file does not exist', async () => { + const result = await storage.loadManifest({ + location: { sessionId: 'nonexistent', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toEqual({ + schemaVersion: '1.0', + updatedAt: expect.any(String), + }) + }) + }) + + describe('FileSnapshotStorage_When_InvalidManifestJSON_Then_ThrowsSessionError', () => { + it('throws SessionError when JSON is invalid', async () => { + const sessionId = 'test-session' + const filePath = join(testDir, sessionId, 'scopes', 'agent', SCOPE_ID, 'snapshots', 'manifest.json') + + await fs.mkdir(join(testDir, sessionId, 'scopes', 'agent', SCOPE_ID, 'snapshots'), { recursive: true }) + await fs.writeFile(filePath, 'invalid json', 'utf8') + + await expect( + storage.loadManifest({ location: { sessionId, scope: 'agent', scopeId: SCOPE_ID } }) + ).rejects.toThrow(SessionError) + }) + }) + }) +}) diff --git a/strands-ts/src/session/__tests__/s3-storage.test.ts b/strands-ts/src/session/__tests__/s3-storage.test.ts new file mode 100644 index 0000000000..8c0bfcb3dc --- /dev/null +++ b/strands-ts/src/session/__tests__/s3-storage.test.ts @@ -0,0 +1,574 @@ +import { describe, expect, it, beforeEach, vi, type MockedFunction } from 'vitest' +import { S3Storage } from '../s3-storage.js' +import { SessionError } from '../../errors.js' +import { createTestSnapshot, createTestManifest, createTestScope } from '../../__fixtures__/mock-storage-provider.js' +import type { SnapshotLocation } from '../storage.js' + +vi.mock('@aws-sdk/client-s3', () => ({ + S3Client: vi.fn().mockImplementation(function () { + return { + send: vi.fn(), + config: {}, + } + }), + PutObjectCommand: vi.fn().mockImplementation(function (input) { + return { input } + }), + GetObjectCommand: vi.fn().mockImplementation(function (input) { + return { input } + }), + ListObjectsV2Command: vi.fn().mockImplementation(function (input) { + return { input } + }), + DeleteObjectsCommand: vi.fn().mockImplementation(function (input) { + return { input } + }), +})) + +const SCOPE_ID = 'test-agent' + +describe('S3Storage', () => { + let storage: S3Storage + let mockS3Client: { send: MockedFunction } + + beforeEach(() => { + vi.clearAllMocks() + storage = new S3Storage({ bucket: 'test-bucket', region: 'us-east-1' }) + mockS3Client = (storage as any)._s3 + }) + + describe('constructor', () => { + describe('S3SnapshotStorage_When_ValidConfig_Then_CreatesInstance', () => { + it('stores bucket and region configuration', () => { + const instance = new S3Storage({ bucket: 'test-bucket', region: 'us-west-2' }) + expect((instance as any)._bucket).toBe('test-bucket') + expect((instance as any)._s3).toBeDefined() + }) + + it('stores prefix when provided', () => { + const instance = new S3Storage({ bucket: 'test-bucket', prefix: 'my-prefix', region: 'us-east-1' }) + expect((instance as any)._prefix).toBe('my-prefix') + }) + + it('uses provided S3 client instead of creating new one', () => { + const customClient = { send: vi.fn() } + const instance = new S3Storage({ bucket: 'test-bucket', s3Client: customClient as any }) + expect((instance as any)._s3).toBe(customClient) + }) + + it('throws error when both s3Client and region are provided', () => { + const config = { bucket: 'test-bucket', region: 'us-west-2', s3Client: { send: vi.fn() } as any } + expect(() => new S3Storage(config)).toThrow(SessionError) + expect(() => new S3Storage(config)).toThrow('Cannot specify both s3Client and region') + }) + }) + }) + + describe('saveSnapshot', () => { + describe('S3SnapshotStorage_When_saveSnapshot_Then_PutsObjects', () => { + it('saves snapshot to S3 history', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + mockS3Client.send.mockResolvedValue({}) + + await storage.saveSnapshot({ location, snapshotId: '1', isLatest: false, snapshot }) + + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: { + Bucket: 'test-bucket', + Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_1.json`, + Body: JSON.stringify(snapshot, null, 2), + ContentType: 'application/json', + }, + }) + ) + }) + + it('saves snapshot as latest when isLatest is true', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: createTestScope(), scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + mockS3Client.send.mockResolvedValue({}) + + await storage.saveSnapshot({ location, snapshotId: '1', isLatest: true, snapshot }) + + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/snapshot_latest.json`, + }), + }) + ) + }) + + it('uses prefix when configured', async () => { + const storageWithPrefix = new S3Storage({ bucket: 'test-bucket', prefix: 'my-app', region: 'us-east-1' }) + const mockPrefixS3Client = (storageWithPrefix as any)._s3 + const location: SnapshotLocation = { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + mockPrefixS3Client.send.mockResolvedValue({}) + + await storageWithPrefix.saveSnapshot({ location, snapshotId: '1', isLatest: false, snapshot }) + + expect(mockPrefixS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + Key: `my-app/test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_1.json`, + }), + }) + ) + }) + }) + + describe('S3SnapshotStorage_When_saveSnapshotFails_Then_ThrowsSessionError', () => { + it('throws SessionError when S3 put fails', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + mockS3Client.send.mockRejectedValue(new Error('S3 error')) + + await expect(storage.saveSnapshot({ location, snapshotId: '1', isLatest: false, snapshot })).rejects.toThrow( + 'Failed to write S3 object' + ) + }) + }) + + describe('S3SnapshotStorage_When_MultiAgentScope_Then_SavesCorrectly', () => { + it('saves multi-agent snapshot to correct S3 key', async () => { + const location: SnapshotLocation = { sessionId: 'multi-session', scope: 'multiAgent', scopeId: 'graph-1' } + const snapshot = createTestSnapshot({ scope: 'multiAgent' }) + mockS3Client.send.mockResolvedValue({}) + + await storage.saveSnapshot({ location, snapshotId: '1', isLatest: true, snapshot }) + + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + Key: 'multi-session/scopes/multiAgent/graph-1/snapshots/snapshot_latest.json', + }), + }) + ) + }) + }) + }) + + describe('loadSnapshot', () => { + describe('S3SnapshotStorage_When_LoadLatestSnapshot_Then_ReturnsSnapshot', () => { + it('loads latest snapshot when snapshotId is undefined', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + mockS3Client.send.mockResolvedValue({ + Body: { transformToString: () => Promise.resolve(JSON.stringify(snapshot)) }, + }) + + const result = await storage.loadSnapshot({ location }) + + expect(result).toEqual(snapshot) + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: { + Bucket: 'test-bucket', + Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/snapshot_latest.json`, + }, + }) + ) + }) + }) + + describe('S3SnapshotStorage_When_LoadSpecificSnapshot_Then_ReturnsSnapshot', () => { + it('loads specific snapshot by ID', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } + const snapshot = createTestSnapshot() + mockS3Client.send.mockResolvedValue({ + Body: { transformToString: () => Promise.resolve(JSON.stringify(snapshot)) }, + }) + + const result = await storage.loadSnapshot({ location, snapshotId: '5' }) + + expect(result).toEqual(snapshot) + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_5.json`, + }), + }) + ) + }) + }) + + describe('S3SnapshotStorage_When_SnapshotNotFound_Then_ReturnsNull', () => { + it('returns null when S3 object does not exist', async () => { + const noSuchKeyError = Object.assign(new Error('NoSuchKey'), { name: 'NoSuchKey' }) + mockS3Client.send.mockRejectedValue(noSuchKeyError) + + const result = await storage.loadSnapshot({ + location: { sessionId: 'nonexistent', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toBeNull() + }) + + it('returns null when S3 response has no body', async () => { + mockS3Client.send.mockResolvedValue({ Body: null }) + const result = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toBeNull() + }) + + it('returns null when S3 response body is empty', async () => { + mockS3Client.send.mockResolvedValue({ Body: { transformToString: () => Promise.resolve('') } }) + const result = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toBeNull() + }) + }) + + describe('S3SnapshotStorage_When_InvalidJSON_Then_ThrowsSessionError', () => { + it('throws SessionError when JSON is invalid', async () => { + mockS3Client.send.mockResolvedValue({ + Body: { transformToString: () => Promise.resolve('invalid json') }, + }) + await expect( + storage.loadSnapshot({ location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } }) + ).rejects.toThrow('Invalid JSON in S3 object') + }) + }) + + describe('S3SnapshotStorage_When_S3Error_Then_ThrowsSessionError', () => { + it('throws SessionError when S3 get fails', async () => { + mockS3Client.send.mockRejectedValue(new Error('S3 error')) + await expect( + storage.loadSnapshot({ location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } }) + ).rejects.toThrow('S3 error reading') + }) + }) + }) + + describe('listSnapshotIds', () => { + describe('S3SnapshotStorage_When_listSnapshots_Then_ReturnsOrderedIds', () => { + it('returns sorted snapshot IDs', async () => { + const ids = [ + '019c9bf1-14e5-7eef-96fb-cc07ae54210f', + '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd', + '019c9bf1-24bb-7eef-96fb-ddcc943cd859', + ] + // S3 returns objects in lexicographic key order — mock reflects that contract + mockS3Client.send.mockResolvedValue({ + Contents: [ + { Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${ids[0]}.json` }, + { Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${ids[1]}.json` }, + { Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${ids[2]}.json` }, + ], + }) + + const result = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + }) + + expect(result).toEqual(ids) + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + Bucket: 'test-bucket', + Prefix: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/`, + MaxKeys: 1000, + }), + }) + ) + }) + + it('returns empty array when no objects exist', async () => { + mockS3Client.send.mockResolvedValue({ Contents: [] }) + const result = await storage.listSnapshotIds({ + location: { sessionId: 'empty-session', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toEqual([]) + }) + + it('ignores non-snapshot objects', async () => { + const id1 = '019c9bf1-14e5-7eef-96fb-cc07ae54210f' + const id2 = '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd' + mockS3Client.send.mockResolvedValue({ + Contents: [ + { Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${id1}.json` }, + { Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/other-file.txt` }, + { Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${id2}.json` }, + ], + }) + const result = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toEqual([id1, id2]) + }) + + it('handles objects without Key property', async () => { + const id1 = '019c9bf1-14e5-7eef-96fb-cc07ae54210f' + const id2 = '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd' + mockS3Client.send.mockResolvedValue({ + Contents: [ + { Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${id1}.json` }, + {}, + { Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${id2}.json` }, + ], + }) + const result = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toEqual([id1, id2]) + }) + + it('filters by startAfter for pagination', async () => { + const ids = [ + '019c9bf1-14e5-7eef-96fb-cc07ae54210f', + '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd', + '019c9bf1-24bb-7eef-96fb-ddcc943cd859', + ] + // Simulate S3 server-side StartAfter: only return objects after ids[0] + mockS3Client.send.mockResolvedValue({ + Contents: [ids[1], ids[2]].map((id) => ({ + Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${id}.json`, + })), + }) + + const result = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + startAfter: ids[0]!, + }) + + expect(result).toEqual([ids[1], ids[2]]) + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + StartAfter: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${ids[0]}.json`, + }), + }) + ) + }) + + it('limits results when limit is provided', async () => { + const ids = [ + '019c9bf1-14e5-7eef-96fb-cc07ae54210f', + '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd', + '019c9bf1-24bb-7eef-96fb-ddcc943cd859', + ] + mockS3Client.send.mockResolvedValue({ + Contents: ids.map((id) => ({ + Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${id}.json`, + })), + }) + + const result = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + limit: 2, + }) + + expect(result).toEqual([ids[0], ids[1]]) + }) + + it('combines startAfter and limit', async () => { + const ids = [ + '019c9bf1-14e5-7eef-96fb-cc07ae54210f', + '019c9bf1-1d34-7eef-96fb-d1be20fd7bbd', + '019c9bf1-24bb-7eef-96fb-ddcc943cd859', + ] + // Simulate S3 server-side StartAfter: only return objects after ids[0] + mockS3Client.send.mockResolvedValue({ + Contents: [ids[1], ids[2]].map((id) => ({ + Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/immutable_history/snapshot_${id}.json`, + })), + }) + + const result = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + startAfter: ids[0]!, + limit: 1, + }) + + expect(result).toEqual([ids[1]]) + }) + }) + + describe('S3SnapshotStorage_When_ListObjectsFails_Then_ThrowsSessionError', () => { + it('throws SessionError when S3 list fails', async () => { + mockS3Client.send.mockRejectedValue(new Error('S3 list error')) + await expect( + storage.listSnapshotIds({ location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } }) + ).rejects.toThrow('Failed to list snapshots for session test-session') + }) + }) + }) + + describe('deleteSession', () => { + describe('S3SnapshotStorage_When_DeleteSession_Then_DeletesAllObjects', () => { + it('deletes all objects under the session prefix', async () => { + mockS3Client.send + .mockResolvedValueOnce({ + Contents: [ + { Key: 'test-session/scopes/agent/agent-1/snapshots/snapshot_latest.json' }, + { Key: 'test-session/scopes/agent/agent-1/snapshots/immutable_history/snapshot_abc.json' }, + ], + IsTruncated: false, + }) + .mockResolvedValueOnce({}) + + await storage.deleteSession({ sessionId: 'test-session' }) + + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + Bucket: 'test-bucket', + Prefix: 'test-session/', + }), + }) + ) + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: { + Bucket: 'test-bucket', + Delete: { + Objects: [ + { Key: 'test-session/scopes/agent/agent-1/snapshots/snapshot_latest.json' }, + { Key: 'test-session/scopes/agent/agent-1/snapshots/immutable_history/snapshot_abc.json' }, + ], + }, + }, + }) + ) + }) + + it('paginates when session has more than 1000 objects', async () => { + mockS3Client.send + .mockResolvedValueOnce({ + Contents: [{ Key: 'test-session/page-1-object.json' }], + IsTruncated: true, + NextContinuationToken: 'token-1', + }) + .mockResolvedValueOnce({}) + .mockResolvedValueOnce({ + Contents: [{ Key: 'test-session/page-2-object.json' }], + IsTruncated: false, + }) + .mockResolvedValueOnce({}) + + await storage.deleteSession({ sessionId: 'test-session' }) + + expect(mockS3Client.send).toHaveBeenCalledTimes(4) + }) + + it('no-ops when session has no objects', async () => { + mockS3Client.send.mockResolvedValueOnce({ Contents: [], IsTruncated: false }) + + await storage.deleteSession({ sessionId: 'empty-session' }) + + expect(mockS3Client.send).toHaveBeenCalledTimes(1) + }) + + it('uses prefix when configured', async () => { + const storageWithPrefix = new S3Storage({ bucket: 'test-bucket', prefix: 'my-app', region: 'us-east-1' }) + const mockPrefixS3Client = (storageWithPrefix as any)._s3 + mockPrefixS3Client.send.mockResolvedValueOnce({ Contents: [], IsTruncated: false }) + + await storageWithPrefix.deleteSession({ sessionId: 'test-session' }) + + expect(mockPrefixS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ Prefix: 'my-app/test-session/' }), + }) + ) + }) + }) + + describe('S3SnapshotStorage_When_DeleteSessionFails_Then_ThrowsSessionError', () => { + it('throws SessionError when S3 list fails during delete', async () => { + mockS3Client.send.mockRejectedValue(new Error('S3 error')) + await expect(storage.deleteSession({ sessionId: 'test-session' })).rejects.toThrow( + 'Failed to delete session test-session' + ) + }) + }) + }) + + describe('loadManifest', () => { + describe('S3SnapshotStorage_When_LoadManifest_Then_ReturnsManifest', () => { + it('loads existing manifest', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } + const manifest = createTestManifest() + mockS3Client.send.mockResolvedValue({ + Body: { transformToString: () => Promise.resolve(JSON.stringify(manifest)) }, + }) + + const result = await storage.loadManifest({ location }) + + expect(result).toEqual(manifest) + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: expect.objectContaining({ + Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/manifest.json`, + }), + }) + ) + }) + }) + + describe('S3SnapshotStorage_When_ManifestNotFound_Then_ReturnsDefault', () => { + it('returns default manifest when S3 object does not exist', async () => { + const noSuchKeyError = Object.assign(new Error('NoSuchKey'), { name: 'NoSuchKey' }) + mockS3Client.send.mockRejectedValue(noSuchKeyError) + + const result = await storage.loadManifest({ + location: { sessionId: 'nonexistent', scope: 'agent', scopeId: SCOPE_ID }, + }) + expect(result).toEqual({ + schemaVersion: '1.0', + updatedAt: expect.any(String), + }) + }) + }) + + describe('S3SnapshotStorage_When_InvalidManifestJSON_Then_ThrowsSessionError', () => { + it('throws SessionError when manifest JSON is invalid', async () => { + mockS3Client.send.mockResolvedValue({ + Body: { transformToString: () => Promise.resolve('invalid json') }, + }) + await expect( + storage.loadManifest({ location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } }) + ).rejects.toThrow(SessionError) + }) + }) + }) + + describe('saveManifest', () => { + describe('S3SnapshotStorage_When_SaveManifest_Then_PutsObject', () => { + it('saves manifest to S3', async () => { + const location: SnapshotLocation = { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID } + const manifest = createTestManifest() + mockS3Client.send.mockResolvedValue({}) + + await storage.saveManifest({ location, manifest }) + + expect(mockS3Client.send).toHaveBeenCalledWith( + expect.objectContaining({ + input: { + Bucket: 'test-bucket', + Key: `test-session/scopes/agent/${SCOPE_ID}/snapshots/manifest.json`, + Body: JSON.stringify(manifest, null, 2), + ContentType: 'application/json', + }, + }) + ) + }) + }) + + describe('S3SnapshotStorage_When_SaveManifestFails_Then_ThrowsSessionError', () => { + it('throws SessionError when S3 put fails', async () => { + mockS3Client.send.mockRejectedValue(new Error('S3 error')) + await expect( + storage.saveManifest({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: SCOPE_ID }, + manifest: createTestManifest(), + }) + ).rejects.toThrow(SessionError) + }) + }) + }) +}) diff --git a/strands-ts/src/session/__tests__/session-manager.test.ts b/strands-ts/src/session/__tests__/session-manager.test.ts new file mode 100644 index 0000000000..0d128fe178 --- /dev/null +++ b/strands-ts/src/session/__tests__/session-manager.test.ts @@ -0,0 +1,1083 @@ +import { describe, expect, it, beforeEach, vi } from 'vitest' +import { SessionManager } from '../session-manager.js' +import { MockSnapshotStorage, createTestSnapshot } from '../../__fixtures__/mock-storage-provider.js' +import { + InitializedEvent, + MessageAddedEvent, + AfterInvocationEvent, + AfterModelCallEvent, + HookableEvent, + type HookableEventConstructor, + type HookCallback, + type HookCleanup, +} from '../../hooks/index.js' +import { Agent } from '../../agent/agent.js' +import { Message, TextBlock } from '../../types/messages.js' +import { + createMockAgent as createMockAgentWithHooks, + invokeTrackedHook, + type TrackedHook, +} from '../../__fixtures__/agent-helpers.js' +import { loadStateFromJSONSymbol, stateToJSONSymbol } from '../../types/serializable.js' +import { StateStore } from '../../state-store.js' +import { logger } from '../../logging/logger.js' +import { + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + Graph, + type MultiAgent, + MultiAgentState, + NodeResult, + Status, +} from '../../multiagent/index.js' +import { takeSnapshot, loadSnapshot } from '../../agent/snapshot.js' +import type { Snapshot } from '../../types/snapshot.js' +import type { TakeSnapshotOptions } from '../../agent/snapshot.js' + +// Test fixtures +function createMockAgent(id = 'agent'): Agent { + const agent = { + id, + messages: [], + appState: { + _m: new Map(), + get(k: string) { + return this._m.get(k) + }, + set(k: string, v: unknown) { + this._m.set(k, v) + }, + [stateToJSONSymbol]() { + return Object.fromEntries(this._m) + }, + [loadStateFromJSONSymbol](json: Record) { + Object.entries(json).forEach(([k, v]) => this._m.set(k, v)) + }, + } as any, + modelState: new StateStore(), + systemPrompt: 'Test prompt', + takeSnapshot(options: TakeSnapshotOptions): Snapshot { + return takeSnapshot(agent as any, options) + }, + loadSnapshot(snapshot: Snapshot): void { + loadSnapshot(agent as any, snapshot) + }, + } as unknown as Agent + return agent +} + +const MOCK_MESSAGE = new Message({ role: 'user', content: [new TextBlock('test')] }) + +function createMockEvent(agent: Agent) { + return { agent, invocationState: {} } +} + +function createMockMessageEvent(agent: Agent) { + return { agent, message: MOCK_MESSAGE, invocationState: {} } +} + +async function initPluginAndInvokeHook( + sessionManager: SessionManager, + event: T +): Promise { + const pluginAgent = createMockAgentWithHooks() + sessionManager.initAgent(pluginAgent) + await invokeTrackedHook(pluginAgent, event) +} + +describe('SessionManager', () => { + let storage: MockSnapshotStorage + let sessionManager: SessionManager + let mockAgent: Agent + + beforeEach(() => { + storage = new MockSnapshotStorage() + mockAgent = createMockAgent() + }) + + describe('constructor', () => { + it('defaults saveLatestOn to invocation', async () => { + sessionManager = new SessionManager({ sessionId: 'test-default', storage: { snapshot: storage } }) + + await initPluginAndInvokeHook(sessionManager, new AfterInvocationEvent(createMockEvent(mockAgent))) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-default', scope: 'agent', scopeId: 'agent' }, + }) + expect(snapshot).not.toBeNull() + }) + }) + + describe('saveSnapshot', () => { + beforeEach(() => { + mockAgent = createMockAgent('test-agent') + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + }) + + it('saves snapshot_latest when isLatest is true', async () => { + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: true }) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(snapshot).not.toBeNull() + expect(snapshot?.scope).toBe('agent') + }) + + it('saves immutable snapshot when isLatest is false', async () => { + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + + const ids = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(ids.length).toBeGreaterThan(0) + }) + + it('allocates unique snapshot IDs', async () => { + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + + const ids = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(ids.length).toBe(3) + }) + }) + + describe('listSnapshotIds', () => { + beforeEach(() => { + mockAgent = createMockAgent('test-agent') + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + }) + + it('returns empty array when no snapshots exist', async () => { + const ids = await sessionManager.listSnapshotIds({ target: mockAgent }) + expect(ids).toStrictEqual([]) + }) + + it('returns snapshot IDs for the target agent', async () => { + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + + const ids = await sessionManager.listSnapshotIds({ target: mockAgent }) + expect(ids).toHaveLength(2) + }) + + it('does not return latest snapshot ID', async () => { + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: true }) + + const ids = await sessionManager.listSnapshotIds({ target: mockAgent }) + expect(ids).toStrictEqual([]) + }) + + it('forwards limit parameter', async () => { + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + + const ids = await sessionManager.listSnapshotIds({ target: mockAgent, limit: 2 }) + expect(ids).toHaveLength(2) + }) + + it('forwards startAfter parameter', async () => { + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + await sessionManager.saveSnapshot({ target: mockAgent, isLatest: false }) + + const allIds = await sessionManager.listSnapshotIds({ target: mockAgent }) + const page2 = await sessionManager.listSnapshotIds({ target: mockAgent, startAfter: allIds[0]! }) + expect(page2).toHaveLength(1) + expect(page2[0]).toBe(allIds[1]) + }) + }) + + describe('restoreSnapshot', () => { + beforeEach(() => { + mockAgent = createMockAgent('test-agent') + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + }) + + it('restores snapshot_latest when no snapshotId provided', async () => { + const snapshot = createTestSnapshot() + await storage.saveSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + snapshotId: 'latest', + isLatest: true, + snapshot, + }) + + const result = await sessionManager.restoreSnapshot({ target: mockAgent }) + + expect(result).toBe(true) + }) + + it('restores specific snapshot by ID', async () => { + const snapshot = createTestSnapshot() + await storage.saveSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + snapshotId: '5', + isLatest: false, + snapshot, + }) + + const result = await sessionManager.restoreSnapshot({ target: mockAgent, snapshotId: '5' }) + + expect(result).toBe(true) + }) + + it('returns false when snapshot not found', async () => { + const result = await sessionManager.restoreSnapshot({ target: mockAgent, snapshotId: '999' }) + + expect(result).toBe(false) + }) + }) + + describe('InitializedEvent handling', () => { + beforeEach(() => { + mockAgent = createMockAgent('test-agent') + }) + + it('loads snapshot_latest on initialization', async () => { + const snapshot = createTestSnapshot() + await storage.saveSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + snapshotId: 'latest', + isLatest: true, + snapshot, + }) + + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + + await initPluginAndInvokeHook(sessionManager, new InitializedEvent(createMockEvent(mockAgent))) + + expect(mockAgent.messages).toEqual(snapshot.data.messages) + }) + + it('handles missing snapshot gracefully', async () => { + sessionManager = new SessionManager({ + sessionId: 'new-session', + storage: { snapshot: storage }, + }) + + await expect( + initPluginAndInvokeHook(sessionManager, new InitializedEvent(createMockEvent(mockAgent))) + ).resolves.not.toThrow() + }) + + it('warns when snapshot restore overwrites existing messages', async () => { + const warnSpy = vi.spyOn(logger, 'warn') + + const snapshot = createTestSnapshot() + await storage.saveSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + snapshotId: 'latest', + isLatest: true, + snapshot, + }) + + mockAgent.messages.push(MOCK_MESSAGE) + + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + + await initPluginAndInvokeHook(sessionManager, new InitializedEvent(createMockEvent(mockAgent))) + + expect(warnSpy).toHaveBeenCalledWith(expect.stringContaining('overwritten by session restore')) + warnSpy.mockRestore() + }) + + it('does not warn when restoring into agent with no messages', async () => { + const warnSpy = vi.spyOn(logger, 'warn') + + const snapshot = createTestSnapshot() + await storage.saveSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + snapshotId: 'latest', + isLatest: true, + snapshot, + }) + + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + + await initPluginAndInvokeHook(sessionManager, new InitializedEvent(createMockEvent(mockAgent))) + + expect(warnSpy).not.toHaveBeenCalled() + warnSpy.mockRestore() + }) + }) + + describe('MessageAddedEvent handling', () => { + beforeEach(() => { + mockAgent = createMockAgent('test-agent') + }) + + it('saves snapshot_latest when saveLatestOn is message', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'message', + }) + + await initPluginAndInvokeHook(sessionManager, new MessageAddedEvent(createMockMessageEvent(mockAgent))) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(snapshot).not.toBeNull() + }) + + it('does not save when saveLatestOn is invocation', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'invocation', + }) + + // MessageAddedEvent is not registered when saveLatestOn is 'invocation' + // So we need to call initAgent and check that no hook is registered for MessageAddedEvent + const pluginAgent = createMockAgentWithHooks() + sessionManager.initAgent(pluginAgent) + + // Verify MessageAddedEvent hook is not registered + const messageHook = pluginAgent.trackedHooks.find((h) => h.eventType === MessageAddedEvent) + expect(messageHook).toBeUndefined() + + // Even if we try to invoke (nothing should happen) + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(snapshot).toBeNull() + }) + }) + + describe('AfterInvocationEvent handling', () => { + beforeEach(() => { + mockAgent = createMockAgent('test-agent') + }) + + it('saves snapshot_latest when saveLatestOn is invocation', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'invocation', + }) + + await initPluginAndInvokeHook(sessionManager, new AfterInvocationEvent(createMockEvent(mockAgent))) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(snapshot).not.toBeNull() + }) + + it('does not save snapshot_latest when saveLatestOn is trigger', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + }) + + await initPluginAndInvokeHook(sessionManager, new AfterInvocationEvent(createMockEvent(mockAgent))) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(snapshot).toBeNull() + }) + }) + + describe('snapshotTrigger', () => { + beforeEach(() => { + mockAgent = createMockAgent('test-agent') + }) + + it('creates immutable snapshot when trigger returns true', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + snapshotTrigger: () => true, + }) + + await initPluginAndInvokeHook(sessionManager, new AfterInvocationEvent(createMockEvent(mockAgent))) + + const ids = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(ids.length).toBe(1) + }) + + it('does not create immutable snapshot when trigger returns false', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + snapshotTrigger: () => false, + }) + + await initPluginAndInvokeHook(sessionManager, new AfterInvocationEvent(createMockEvent(mockAgent))) + + const ids = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(ids.length).toBe(0) + }) + + it('provides agentData to trigger', async () => { + const triggerSpy = vi.fn(() => false) + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + snapshotTrigger: triggerSpy, + }) + + await initPluginAndInvokeHook(sessionManager, new AfterInvocationEvent(createMockEvent(mockAgent))) + + expect(triggerSpy).toHaveBeenCalledWith( + expect.objectContaining({ + agentData: expect.objectContaining({ + appState: mockAgent.appState, + messages: mockAgent.messages, + }), + }) + ) + }) + + it('saves both immutable and latest when trigger fires', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + snapshotTrigger: () => true, + }) + + await initPluginAndInvokeHook(sessionManager, new AfterInvocationEvent(createMockEvent(mockAgent))) + + const immutableIds = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + const latest = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + + expect(immutableIds.length).toBe(1) + expect(latest).not.toBeNull() + }) + + it('trigger based on message count via agentData', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + snapshotTrigger: ({ agentData }) => agentData.messages.length >= 2, + }) + + const pluginAgent = createMockAgentWithHooks() + sessionManager.initAgent(pluginAgent) + + await invokeTrackedHook(pluginAgent, new AfterInvocationEvent(createMockEvent(mockAgent))) + let ids = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(ids.length).toBe(0) // 0 messages — no snapshot + + mockAgent.messages.push(MOCK_MESSAGE, MOCK_MESSAGE) + await invokeTrackedHook(pluginAgent, new AfterInvocationEvent(createMockEvent(mockAgent))) + ids = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(ids.length).toBe(1) // 2 messages — snapshot taken + }) + + it('trigger based on agent state via agentData', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + snapshotTrigger: ({ agentData }) => (agentData.appState as any).get('checkpoint') === true, + }) + + const pluginAgent = createMockAgentWithHooks() + sessionManager.initAgent(pluginAgent) + + await invokeTrackedHook(pluginAgent, new AfterInvocationEvent(createMockEvent(mockAgent))) + let ids = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(ids.length).toBe(0) // state not set — no snapshot + + mockAgent.appState.set('checkpoint', true) + await invokeTrackedHook(pluginAgent, new AfterInvocationEvent(createMockEvent(mockAgent))) + ids = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(ids.length).toBe(1) // state set — snapshot taken + }) + }) + + describe('integration scenarios', () => { + it('handles complete session lifecycle', async () => { + sessionManager = new SessionManager({ + sessionId: 'lifecycle-test', + storage: { snapshot: storage }, + saveLatestOn: 'invocation', + snapshotTrigger: () => true, + }) + + const pluginAgent = createMockAgentWithHooks() + sessionManager.initAgent(pluginAgent) + + await invokeTrackedHook(pluginAgent, new InitializedEvent(createMockEvent(mockAgent))) + await invokeTrackedHook(pluginAgent, new AfterInvocationEvent(createMockEvent(mockAgent))) + await invokeTrackedHook(pluginAgent, new AfterInvocationEvent(createMockEvent(mockAgent))) + await invokeTrackedHook(pluginAgent, new AfterInvocationEvent(createMockEvent(mockAgent))) + + const latest = await storage.loadSnapshot({ + location: { sessionId: 'lifecycle-test', scope: 'agent', scopeId: 'agent' }, + }) + const immutableIds = await storage.listSnapshotIds({ + location: { sessionId: 'lifecycle-test', scope: 'agent', scopeId: 'agent' }, + }) + + expect(latest).not.toBeNull() + expect(immutableIds.length).toBe(3) + }) + + it('supports resuming from immutable snapshot', async () => { + // First session - snapshot fires when messages.length === 2 (after turn 1) + sessionManager = new SessionManager({ + sessionId: 'resume-test', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + snapshotTrigger: ({ agentData }) => agentData.messages.length === 2, + }) + + const pluginAgent = createMockAgentWithHooks() + sessionManager.initAgent(pluginAgent) + + await invokeTrackedHook(pluginAgent, new InitializedEvent(createMockEvent(mockAgent))) + mockAgent.messages.push(MOCK_MESSAGE, MOCK_MESSAGE) + await invokeTrackedHook(pluginAgent, new AfterInvocationEvent(createMockEvent(mockAgent))) + + const ids = await storage.listSnapshotIds({ + location: { sessionId: 'resume-test', scope: 'agent', scopeId: 'agent' }, + }) + expect(ids.length).toBe(1) + + // Second session - resume from that snapshot + const newAgent = createMockAgent() + const newSessionManager = new SessionManager({ + sessionId: 'resume-test', + storage: { snapshot: storage }, + saveLatestOn: 'invocation', + }) + + const newAgentData = createMockAgentWithHooks() + newSessionManager.initAgent(newAgentData) + + await invokeTrackedHook(newAgentData, new InitializedEvent(createMockEvent(newAgent))) + await newSessionManager.restoreSnapshot({ target: newAgent, snapshotId: ids[0]! }) + + expect(newAgent.messages).toEqual(mockAgent.messages) + }) + }) + + describe('AfterModelCallEvent with redaction handling', () => { + beforeEach(() => { + mockAgent = createMockAgent('test-agent') + }) + + it('saves snapshot_latest when saveLatestOn is message and redaction occurred', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'message', + }) + + const assistantMessage = new Message({ role: 'assistant', content: [new TextBlock('Response')] }) + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + stopData: { + message: assistantMessage, + stopReason: 'endTurn' as const, + redaction: { userMessage: '[User input redacted.]' }, + }, + } as any) + + await initPluginAndInvokeHook(sessionManager, event) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(snapshot).not.toBeNull() + }) + + it('does not save when saveLatestOn is message but no redaction occurred', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'message', + }) + + const assistantMessage = new Message({ role: 'assistant', content: [new TextBlock('Response')] }) + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + stopData: { message: assistantMessage, stopReason: 'endTurn' as const }, + } as any) + + await initPluginAndInvokeHook(sessionManager, event) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(snapshot).toBeNull() + }) + + it.each(['invocation', 'message'] as const)( + 'saves snapshot_latest on redaction when saveLatestOn is %s', + async (saveLatestOn) => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn, + }) + + const assistantMessage = new Message({ role: 'assistant', content: [new TextBlock('Response')] }) + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + stopData: { + message: assistantMessage, + stopReason: 'endTurn' as const, + redaction: { userMessage: '[User input redacted.]' }, + }, + } as any) + + await initPluginAndInvokeHook(sessionManager, event) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(snapshot).not.toBeNull() + } + ) + + it('does not register AfterModelCallEvent hook when saveLatestOn is trigger', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + }) + + const pluginAgent = createMockAgentWithHooks() + sessionManager.initAgent(pluginAgent) + const afterModelHook = pluginAgent.trackedHooks.find((h) => h.eventType === AfterModelCallEvent) + expect(afterModelHook).toBeUndefined() + }) + + it('does not save on AfterModelCallEvent without redaction under saveLatestOn=invocation', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'invocation', + }) + + const assistantMessage = new Message({ role: 'assistant', content: [new TextBlock('Response')] }) + const event = new AfterModelCallEvent({ + agent: mockAgent, + model: {} as any, + stopData: { message: assistantMessage, stopReason: 'endTurn' as const }, + } as any) + + await initPluginAndInvokeHook(sessionManager, event) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'agent', scopeId: 'test-agent' }, + }) + expect(snapshot).toBeNull() + }) + }) +}) + +// --------------------------------------------------------------------------- +// Multi-agent tests +// --------------------------------------------------------------------------- + +type MockOrchestrator = MultiAgent & { + trackedHooks: TrackedHook[] + nodes: ReadonlyMap +} + +function createMockOrchestrator(id = 'graph'): MockOrchestrator { + const trackedHooks: TrackedHook[] = [] + return { + id, + nodes: new Map(), + invoke: vi.fn(), + stream: vi.fn(), + addHook: ( + eventType: HookableEventConstructor, + callback: HookCallback + ): HookCleanup => { + trackedHooks.push({ + eventType: eventType as HookableEventConstructor, + callback: callback as HookCallback, + }) + return () => {} + }, + trackedHooks, + } as unknown as MockOrchestrator +} + +function invokeOrchestratorHook(orchestrator: MockOrchestrator, event: T): Promise { + const hook = orchestrator.trackedHooks.find((h) => h.eventType === event.constructor) + if (!hook) throw new Error(`No hook registered for event type: ${event.constructor.name}`) + return hook.callback(event) as Promise +} + +function createMultiAgentTestSnapshot(orchestratorId = 'test-graph'): ReturnType { + return createTestSnapshot({ scope: 'multiAgent', data: { orchestratorId } }) +} + +describe('SessionManager — multi-agent', () => { + let storage: MockSnapshotStorage + let sessionManager: SessionManager + let orchestrator: MockOrchestrator + + beforeEach(() => { + storage = new MockSnapshotStorage() + orchestrator = createMockOrchestrator('test-graph') + }) + + describe('initMultiAgent', () => { + it('registers BeforeMultiAgentInvocationEvent hook', () => { + sessionManager = new SessionManager({ sessionId: 'test', storage: { snapshot: storage } }) + sessionManager.initMultiAgent(orchestrator) + + const hook = orchestrator.trackedHooks.find((h) => h.eventType === BeforeMultiAgentInvocationEvent) + expect(hook).toBeDefined() + }) + + it('registers AfterNodeCallEvent hook by default (node strategy)', () => { + sessionManager = new SessionManager({ sessionId: 'test', storage: { snapshot: storage } }) + sessionManager.initMultiAgent(orchestrator) + + const hook = orchestrator.trackedHooks.find((h) => h.eventType === AfterNodeCallEvent) + expect(hook).toBeDefined() + }) + + it('registers AfterMultiAgentInvocationEvent hook when strategy is invocation', () => { + sessionManager = new SessionManager({ + sessionId: 'test', + storage: { snapshot: storage }, + multiAgentSaveLatestOn: 'invocation', + }) + sessionManager.initMultiAgent(orchestrator) + + const hook = orchestrator.trackedHooks.find((h) => h.eventType === AfterMultiAgentInvocationEvent) + expect(hook).toBeDefined() + }) + }) + + describe('saveSnapshot — multi-agent', () => { + beforeEach(() => { + sessionManager = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + }) + + it('saves orchestrator snapshot as latest', async () => { + await sessionManager.saveSnapshot({ target: orchestrator as unknown as Graph, isLatest: true }) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: 'test-graph' }, + }) + expect(snapshot).not.toBeNull() + expect(snapshot?.scope).toBe('multiAgent') + }) + + it('saves orchestrator snapshot with state', async () => { + const state = new MultiAgentState({ nodeIds: ['a'] }) + state.steps = 3 + + await sessionManager.saveSnapshot({ target: orchestrator as unknown as Graph, state, isLatest: true }) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: 'test-graph' }, + }) + expect(snapshot).not.toBeNull() + expect(snapshot?.data.state).toBeDefined() + }) + + it('saves immutable orchestrator snapshot', async () => { + await sessionManager.saveSnapshot({ target: orchestrator as unknown as Graph, isLatest: false }) + + const ids = await storage.listSnapshotIds({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: 'test-graph' }, + }) + expect(ids.length).toBe(1) + }) + }) + + describe('restoreSnapshot — multi-agent', () => { + beforeEach(() => { + sessionManager = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + }) + + it('restores orchestrator snapshot', async () => { + const snapshot = createMultiAgentTestSnapshot() + await storage.saveSnapshot({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: 'test-graph' }, + snapshotId: 'latest', + isLatest: true, + snapshot, + }) + + const result = await sessionManager.restoreSnapshot({ target: orchestrator as unknown as Graph }) + expect(result).toBe(true) + }) + + it('returns false when no snapshot exists', async () => { + const result = await sessionManager.restoreSnapshot({ target: orchestrator as unknown as Graph }) + expect(result).toBe(false) + }) + }) + + describe('AfterMultiAgentInvocationEvent handling', () => { + it('saves snapshot after node call when multiAgentSaveLatestOn is node (default)', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + }) + sessionManager.initMultiAgent(orchestrator) + + const state = new MultiAgentState({ nodeIds: ['a'] }) + await invokeOrchestratorHook( + orchestrator, + new AfterNodeCallEvent({ orchestrator, state, nodeId: 'a', invocationState: {} }) + ) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: 'test-graph' }, + }) + expect(snapshot).not.toBeNull() + expect(snapshot?.scope).toBe('multiAgent') + }) + + it('saves snapshot after invocation when multiAgentSaveLatestOn is invocation', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + multiAgentSaveLatestOn: 'invocation', + }) + sessionManager.initMultiAgent(orchestrator) + + const state = new MultiAgentState({ nodeIds: ['a'] }) + await invokeOrchestratorHook( + orchestrator, + new AfterMultiAgentInvocationEvent({ orchestrator, state, invocationState: {} }) + ) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: 'test-graph' }, + }) + expect(snapshot).not.toBeNull() + expect(snapshot?.scope).toBe('multiAgent') + }) + + it('saves snapshot independently of agent saveLatestOn setting', async () => { + sessionManager = new SessionManager({ + sessionId: 'test-session', + storage: { snapshot: storage }, + saveLatestOn: 'trigger', + multiAgentSaveLatestOn: 'invocation', + }) + sessionManager.initMultiAgent(orchestrator) + + const state = new MultiAgentState({ nodeIds: ['a'] }) + await invokeOrchestratorHook( + orchestrator, + new AfterMultiAgentInvocationEvent({ orchestrator, state, invocationState: {} }) + ) + + const snapshot = await storage.loadSnapshot({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: 'test-graph' }, + }) + expect(snapshot).not.toBeNull() + }) + }) + + describe('scope isolation', () => { + it('agent and multi-agent snapshots use separate storage paths', async () => { + const mockAgent = createMockAgent('test-agent') + sessionManager = new SessionManager({ + sessionId: 'shared-session', + storage: { snapshot: storage }, + }) + + await sessionManager.saveSnapshot({ target: mockAgent as unknown as Agent, isLatest: true }) + await sessionManager.saveSnapshot({ target: orchestrator as unknown as Graph, isLatest: true }) + + const agentSnapshot = await storage.loadSnapshot({ + location: { sessionId: 'shared-session', scope: 'agent', scopeId: 'test-agent' }, + }) + const multiAgentSnapshot = await storage.loadSnapshot({ + location: { sessionId: 'shared-session', scope: 'multiAgent', scopeId: 'test-graph' }, + }) + + expect(agentSnapshot).not.toBeNull() + expect(multiAgentSnapshot).not.toBeNull() + expect(agentSnapshot?.scope).toBe('agent') + expect(multiAgentSnapshot?.scope).toBe('multiAgent') + }) + }) + + describe('BeforeMultiAgentInvocationEvent — state restore', () => { + it('restores state into event.state when snapshot exists', async () => { + const snapshot = createMultiAgentTestSnapshot() + // Build state with a completed node and result + const state = new MultiAgentState({ nodeIds: ['a'] }) + state.steps = 7 + const nodeState = state.node('a')! + nodeState.status = Status.COMPLETED + nodeState.results.push( + new NodeResult({ nodeId: 'a', status: Status.COMPLETED, duration: 100, content: [new TextBlock('done')] }) + ) + const { serializeStateSerializable } = await import('../../types/serializable.js') + snapshot.data.state = serializeStateSerializable(state) + + await storage.saveSnapshot({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: 'test-graph' }, + snapshotId: 'latest', + isLatest: true, + snapshot, + }) + + sessionManager = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + sessionManager.initMultiAgent(orchestrator) + + const freshState = new MultiAgentState({ nodeIds: ['a'] }) + await invokeOrchestratorHook( + orchestrator, + new BeforeMultiAgentInvocationEvent({ orchestrator, state: freshState, invocationState: {} }) + ) + + expect(freshState.steps).toBe(7) + expect(freshState.node('a')?.status).toBe(Status.COMPLETED) + expect(freshState.node('a')?.results).toHaveLength(1) + expect(freshState.node('a')?.results[0]?.nodeId).toBe('a') + expect(freshState.node('a')?.results[0]?.status).toBe(Status.COMPLETED) + expect(freshState.node('a')?.content[0]).toEqual(expect.objectContaining({ text: 'done' })) + }) + + it('does not modify state when no snapshot exists', async () => { + sessionManager = new SessionManager({ sessionId: 'empty-session', storage: { snapshot: storage } }) + sessionManager.initMultiAgent(orchestrator) + + const freshState = new MultiAgentState({ nodeIds: ['a'] }) + await invokeOrchestratorHook( + orchestrator, + new BeforeMultiAgentInvocationEvent({ orchestrator, state: freshState, invocationState: {} }) + ) + + expect(freshState.steps).toBe(0) + }) + + it('restores state independently for two orchestrators sharing one SessionManager', async () => { + const { serializeStateSerializable } = await import('../../types/serializable.js') + + // Set up snapshots for two different orchestrators + const orchestratorA = createMockOrchestrator('graph-a') + const orchestratorB = createMockOrchestrator('swarm-b') + + for (const [orch, steps] of [ + [orchestratorA, 3], + [orchestratorB, 5], + ] as const) { + const snap = createMultiAgentTestSnapshot(orch.id) + const st = new MultiAgentState({ nodeIds: ['x'] }) + st.steps = steps + snap.data.state = serializeStateSerializable(st) + await storage.saveSnapshot({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: orch.id }, + snapshotId: 'latest', + isLatest: true, + snapshot: snap, + }) + } + + sessionManager = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + sessionManager.initMultiAgent(orchestratorA) + sessionManager.initMultiAgent(orchestratorB) + + // First orchestrator restores its own state + const stateA = new MultiAgentState({ nodeIds: ['x'] }) + await invokeOrchestratorHook( + orchestratorA, + new BeforeMultiAgentInvocationEvent({ orchestrator: orchestratorA, state: stateA, invocationState: {} }) + ) + expect(stateA.steps).toBe(3) + + // Second orchestrator also restores — not blocked by the first + const stateB = new MultiAgentState({ nodeIds: ['x'] }) + await invokeOrchestratorHook( + orchestratorB, + new BeforeMultiAgentInvocationEvent({ orchestrator: orchestratorB, state: stateB, invocationState: {} }) + ) + expect(stateB.steps).toBe(5) + }) + + it('consumes snapshot once — second invocation gets fresh state', async () => { + const snapshot = createMultiAgentTestSnapshot() + const state = new MultiAgentState({ nodeIds: ['a'] }) + state.steps = 7 + const { serializeStateSerializable } = await import('../../types/serializable.js') + snapshot.data.state = serializeStateSerializable(state) + + await storage.saveSnapshot({ + location: { sessionId: 'test-session', scope: 'multiAgent', scopeId: 'test-graph' }, + snapshotId: 'latest', + isLatest: true, + snapshot, + }) + + sessionManager = new SessionManager({ sessionId: 'test-session', storage: { snapshot: storage } }) + sessionManager.initMultiAgent(orchestrator) + + // First invocation — state is restored + const firstState = new MultiAgentState({ nodeIds: ['a'] }) + await invokeOrchestratorHook( + orchestrator, + new BeforeMultiAgentInvocationEvent({ orchestrator, state: firstState, invocationState: {} }) + ) + expect(firstState.steps).toBe(7) + + // Second invocation — snapshot already consumed + const secondState = new MultiAgentState({ nodeIds: ['a'] }) + await invokeOrchestratorHook( + orchestrator, + new BeforeMultiAgentInvocationEvent({ orchestrator, state: secondState, invocationState: {} }) + ) + expect(secondState.steps).toBe(0) + }) + }) +}) diff --git a/strands-ts/src/session/__tests__/validation.test.ts b/strands-ts/src/session/__tests__/validation.test.ts new file mode 100644 index 0000000000..abd25547cd --- /dev/null +++ b/strands-ts/src/session/__tests__/validation.test.ts @@ -0,0 +1,62 @@ +import { describe, expect, it } from 'vitest' +import { validateIdentifier, validateUuidV7 } from '../validation.js' + +describe('validateIdentifier', () => { + describe('when identifier is valid', () => { + it('returns the identifier', () => { + expect(validateIdentifier('valid-id')).toBe('valid-id') + }) + }) + + describe('when identifier contains forward slash', () => { + it('throws error', () => { + expect(() => validateIdentifier('invalid/id')).toThrow( + "Identifier 'invalid/id' can only contain lowercase letters, numbers, hyphens, and underscores" + ) + }) + }) + + describe('when identifier contains backslash', () => { + it('throws error', () => { + expect(() => validateIdentifier('invalid\\id')).toThrow( + "Identifier 'invalid\\id' can only contain lowercase letters, numbers, hyphens, and underscores" + ) + }) + }) +}) + +describe('validateUuidV7', () => { + describe('when id is a valid UUID v7', () => { + it('does not throw', () => { + expect(() => validateUuidV7('01956891-2b4c-7000-8abc-123456789abc')).not.toThrow() + }) + }) + + describe('when id is a UUID v4 (wrong version)', () => { + it('throws error', () => { + expect(() => validateUuidV7('550e8400-e29b-41d4-a716-446655440000')).toThrow( + "'550e8400-e29b-41d4-a716-446655440000' is not a valid UUID v7 snapshot ID" + ) + }) + }) + + describe('when id is a timestamp string', () => { + it('throws error', () => { + expect(() => validateUuidV7('2025-01-15T10:30:00Z')).toThrow( + "'2025-01-15T10:30:00Z' is not a valid UUID v7 snapshot ID" + ) + }) + }) + + describe('when id contains path traversal', () => { + it('throws error', () => { + expect(() => validateUuidV7('../evil')).toThrow("'../evil' is not a valid UUID v7 snapshot ID") + }) + }) + + describe('when id is empty string', () => { + it('throws error', () => { + expect(() => validateUuidV7('')).toThrow("'' is not a valid UUID v7 snapshot ID") + }) + }) +}) diff --git a/strands-ts/src/session/file-storage.ts b/strands-ts/src/session/file-storage.ts new file mode 100644 index 0000000000..41b05942ef --- /dev/null +++ b/strands-ts/src/session/file-storage.ts @@ -0,0 +1,219 @@ +import type { SnapshotStorage, SnapshotLocation } from './storage.js' +import type { Snapshot, SnapshotManifest } from './types.js' + +import { SessionError } from '../errors.js' +import { validateIdentifier, validateUuidV7 } from './validation.js' + +const MANIFEST = 'manifest.json' +const SNAPSHOT_LATEST = 'snapshot_latest.json' +const IMMUTABLE_HISTORY = 'immutable_history' +const SNAPSHOT_REGEX = /snapshot_([\w-]+)\.json$/ +const SCHEMA_VERSION = '1.0' + +/** + * File-based implementation of SnapshotStorage. + * Persists session snapshots to the local filesystem under a configurable base directory. + * + * Directory layout: + * ``` + * //scopes///snapshots/ + * snapshot_latest.json + * immutable_history/ + * snapshot_.json + * ``` + */ +export class FileStorage implements SnapshotStorage { + /** Absolute path to the root directory where all session data is stored. */ + private readonly _baseDir: string + + /** + * @param baseDir - Absolute path to the root directory for storing session snapshots. + */ + constructor(baseDir: string) { + this._baseDir = baseDir + } + + /** + * Resolves the absolute file path for a given scope location and filename. + * Validates sessionId and scopeId before constructing the path. + */ + private async _getPath(location: SnapshotLocation, filename: string): Promise { + const { join } = await import('path') + validateIdentifier(location.sessionId) + validateIdentifier(location.scopeId) + return join(this._baseDir, location.sessionId, 'scopes', location.scope, location.scopeId, 'snapshots', filename) + } + + /** + * Resolves the absolute path to the root directory for a session. + * Used by deleteSession to remove all data under `//`. + */ + private async _getSessionDir(sessionId: string): Promise { + const { join } = await import('path') + validateIdentifier(sessionId) + return join(this._baseDir, sessionId) + } + + /** + * Persists a snapshot to disk. + * If `isLatest` is true, writes to `snapshot_latest.json` (overwriting any previous). + * Otherwise, writes to `immutable_history/snapshot_.json`. + */ + async saveSnapshot(params: { + location: SnapshotLocation + snapshotId: string + isLatest: boolean + snapshot: Snapshot + }): Promise { + const path = params.isLatest + ? await this._getLatestSnapshotPath(params.location) + : await this._getHistorySnapshotPath(params.location, params.snapshotId) + await this._writeJSON(path, params.snapshot) + } + + /** + * Loads a snapshot from disk. + * If `snapshotId` is omitted, loads `snapshot_latest.json`. + * Returns null if the file does not exist. + */ + async loadSnapshot(params: { location: SnapshotLocation; snapshotId?: string }): Promise { + const path = + params.snapshotId === undefined + ? await this._getLatestSnapshotPath(params.location) + : await this._getHistorySnapshotPath(params.location, params.snapshotId) + return this._readJSON(path) + } + + /** + * Lists immutable snapshot IDs for a scope, sorted chronologically. + * Since IDs are UUID v7, lexicographic sort equals chronological order. + * `startAfter` filters to IDs after the given UUID v7 (exclusive cursor). + * `limit` caps the number of returned IDs. + * Returns an empty array if no snapshots exist yet. + */ + async listSnapshotIds(params: { + location: SnapshotLocation + limit?: number + startAfter?: string + }): Promise { + if (params.limit !== undefined && params.limit <= 0) return [] + if (params.startAfter) validateUuidV7(params.startAfter) + const dirPath = await this._getPath(params.location, IMMUTABLE_HISTORY) + try { + const { promises: fs } = await import('fs') + const files = await fs.readdir(dirPath) + let ids = files + .map((file) => file.match(SNAPSHOT_REGEX)?.[1]) + .filter((id): id is string => id !== undefined) + .sort() + if (params.startAfter) { + ids = ids.filter((id) => id > params.startAfter!) + } + if (params.limit !== undefined) { + ids = ids.slice(0, params.limit) + } + return ids + } catch (error: unknown) { + if (this._isFileNotFoundError(error)) return [] + throw new SessionError(`Failed to list snapshots for session ${params.location.sessionId}`, { cause: error }) + } + } + + /** + * Deletes all data for a session by removing its root directory (`//`) recursively. + * No-ops if the session directory does not exist. + */ + async deleteSession(params: { sessionId: string }): Promise { + const sessionDir = await this._getSessionDir(params.sessionId) + try { + const { promises: fs } = await import('fs') + await fs.rm(sessionDir, { recursive: true, force: true }) + } catch (error: unknown) { + throw new SessionError(`Failed to delete session ${params.sessionId}`, { cause: error }) + } + } + + /** + * Loads the snapshot manifest for a scope. + * Returns a default manifest with the current timestamp if none exists yet. + */ + async loadManifest(params: { location: SnapshotLocation }): Promise { + const path = await this._getPath(params.location, MANIFEST) + const manifest = await this._readJSON(path) + + return ( + manifest ?? { + schemaVersion: SCHEMA_VERSION, + updatedAt: new Date().toISOString(), + } + ) + } + + /** + * Persists the snapshot manifest for a scope to disk. + */ + async saveManifest(params: { location: SnapshotLocation; manifest: SnapshotManifest }): Promise { + const path = await this._getPath(params.location, MANIFEST) + await this._writeJSON(path, params.manifest) + } + + /** + * Atomically writes JSON to a file using a `.tmp` intermediary to prevent partial writes. + * Creates parent directories if they do not exist. + */ + private async _writeJSON(path: string, data: unknown): Promise { + try { + const { promises: fs } = await import('fs') + const { dirname } = await import('path') + await fs.mkdir(dirname(path), { recursive: true }) + const tmpPath = `${path}.tmp` + await fs.writeFile(tmpPath, JSON.stringify(data, null, 2), 'utf8') + await fs.rename(tmpPath, path) + } catch (error: unknown) { + throw new SessionError(`Failed to write file ${path}`, { cause: error }) + } + } + + /** + * Reads and parses a JSON file. Returns null if the file does not exist. + * Throws SessionError on parse failure or unexpected filesystem errors. + */ + private async _readJSON(path: string): Promise { + try { + const { promises: fs } = await import('fs') + const content = await fs.readFile(path, 'utf8') + return JSON.parse(content) + } catch (error: unknown) { + if (this._isFileNotFoundError(error)) { + return null + } + if (error instanceof SyntaxError) { + throw new SessionError(`Invalid JSON in file ${path}`, { cause: error }) + } + throw new SessionError(`File system error reading ${path}`, { cause: error }) + } + } + + /** Returns true if the error represents a missing file or directory (ENOENT). */ + private _isFileNotFoundError(error: unknown): boolean { + return error !== null && typeof error === 'object' && 'code' in error && error.code === 'ENOENT' + } + + /** Returns the file path for `snapshot_latest.json` within the given scope. */ + private async _getLatestSnapshotPath(location: SnapshotLocation): Promise { + return this._getPath(location, SNAPSHOT_LATEST) + } + + /** + * Returns the file path for an immutable snapshot in `immutable_history/`. + * Validates the snapshotId and guards against path traversal outside `_baseDir`. + */ + private async _getHistorySnapshotPath(location: SnapshotLocation, snapshotId: string): Promise { + validateIdentifier(snapshotId) + const resolved = await this._getPath(location, `${IMMUTABLE_HISTORY}/snapshot_${snapshotId}.json`) + if (!resolved.startsWith(this._baseDir)) { + throw new SessionError(`Invalid snapshotId '${snapshotId}': resolves outside storage directory`) + } + return resolved + } +} diff --git a/strands-ts/src/session/index.ts b/strands-ts/src/session/index.ts new file mode 100644 index 0000000000..6328f9ff9b --- /dev/null +++ b/strands-ts/src/session/index.ts @@ -0,0 +1,17 @@ +/** + * Session management module re-exports. + * These are exported from the main `@strands-agents/sdk` entry point. + */ + +// Core types +export { SessionManager } from './session-manager.js' +export type { SessionManagerConfig, SaveLatestStrategy, MultiAgentSaveLatestStrategy } from './session-manager.js' +export type { SnapshotManifest, SnapshotTriggerCallback, SnapshotTriggerParams } from './types.js' + +// Storage layer +export type { SessionStorage, SnapshotStorage, SnapshotLocation } from './storage.js' + +// Storage implementations +export { FileStorage } from './file-storage.js' + +export type { Scope, Snapshot } from '../types/snapshot.js' diff --git a/strands-ts/src/session/s3-storage.ts b/strands-ts/src/session/s3-storage.ts new file mode 100644 index 0000000000..b60aa4b6b1 --- /dev/null +++ b/strands-ts/src/session/s3-storage.ts @@ -0,0 +1,283 @@ +import { + S3Client, + PutObjectCommand, + GetObjectCommand, + ListObjectsV2Command, + DeleteObjectsCommand, +} from '@aws-sdk/client-s3' +import type { SnapshotStorage, SnapshotLocation } from './storage.js' +import type { Snapshot, SnapshotManifest } from './types.js' +import { SessionError } from '../errors.js' +import { validateIdentifier, validateUuidV7 } from './validation.js' + +const MANIFEST = 'manifest.json' +const SNAPSHOT_LATEST = 'snapshot_latest.json' +const IMMUTABLE_HISTORY = 'immutable_history/' +const SCHEMA_VERSION = '1.0' +const SNAPSHOT_REGEX = /snapshot_([\w-]+)\.json$/ +const S3_PAGE_SIZE = 1000 + +/** + * Configuration options for S3Storage + */ +export type S3StorageConfig = { + /** S3 bucket name */ + bucket: string + /** Optional key prefix for all objects */ + prefix?: string + /** AWS region (default: us-east-1). Cannot be used with s3Client */ + region?: string + /** Pre-configured S3 client. Cannot be used with region */ + s3Client?: S3Client +} + +/** + * S3-based implementation of SnapshotStorage. + * Persists session snapshots as JSON objects in an S3 bucket. + * + * Object key layout: + * ``` + * [/]/scopes///snapshots/ + * snapshot_latest.json + * immutable_history/ + * snapshot_.json + * ``` + */ +export class S3Storage implements SnapshotStorage { + /** S3 client instance */ + private readonly _s3: S3Client + /** S3 bucket name */ + private readonly _bucket: string + /** Key prefix for all objects */ + private readonly _prefix: string + /** + * Creates new S3Storage instance + * @param config - Configuration options + */ + constructor(config: S3StorageConfig) { + if (config.s3Client && config.region) { + throw new SessionError('Cannot specify both s3Client and region. Configure region in the S3Client instead.') + } + + this._bucket = config.bucket + this._prefix = config.prefix ?? '' + this._s3 = config.s3Client ?? new S3Client({ region: config.region ?? 'us-east-1' }) + } + + /** + * Resolves the full S3 object key for a given scope location and path. + * Validates sessionId and scopeId before constructing the key. + */ + private _getKey(location: SnapshotLocation, path: string): string { + validateIdentifier(location.sessionId) + validateIdentifier(location.scopeId) + const base = this._prefix ? `${this._prefix}/` : '' + return `${base}${location.sessionId}/scopes/${location.scope}/${location.scopeId}/snapshots/${path}` + } + + /** + * Resolves the S3 key prefix for an entire session (`[/]/`). + * Used by deleteSession to list and remove all objects under the session. + */ + private _getSessionPrefix(sessionId: string): string { + validateIdentifier(sessionId) + const base = this._prefix ? `${this._prefix}/` : '' + return `${base}${sessionId}/` + } + + /** + * Persists a snapshot to S3. + * If `isLatest` is true, writes to `snapshot_latest.json` (overwriting any previous). + * Otherwise, writes to `immutable_history/snapshot_.json`. + */ + async saveSnapshot(params: { + location: SnapshotLocation + snapshotId: string + isLatest: boolean + snapshot: Snapshot + }): Promise { + if (!params.isLatest) { + await this._writeJSON(this._getHistorySnapshotKey(params.location, params.snapshotId), params.snapshot) + } else { + await this._writeJSON(this._getLatestSnapshotKey(params.location), params.snapshot) + } + } + + /** + * Loads a snapshot from S3. + * If `snapshotId` is omitted, loads `snapshot_latest.json`. + * Returns null if the object does not exist. + */ + async loadSnapshot(params: { location: SnapshotLocation; snapshotId?: string }): Promise { + const key = + params.snapshotId === undefined + ? this._getLatestSnapshotKey(params.location) + : this._getHistorySnapshotKey(params.location, params.snapshotId) + return this._readJSON(key) + } + + /** + * Lists immutable snapshot IDs for a scope, sorted chronologically. + * Since IDs are UUID v7, lexicographic sort equals chronological order. + * Pushes `startAfter` and `limit` down to S3 via `StartAfter` and `MaxKeys` + * to avoid fetching unnecessary objects. + * Returns an empty array if no snapshots exist yet. + */ + async listSnapshotIds(params: { + location: SnapshotLocation + limit?: number + startAfter?: string + }): Promise { + if (params.limit !== undefined && params.limit <= 0) return [] + if (params.startAfter) validateUuidV7(params.startAfter) + + const prefix = this._getKey(params.location, IMMUTABLE_HISTORY) + // S3 StartAfter is a full object key; construct it from the UUID cursor. + // Exclusive: objects after this key are returned, matching our pagination contract. + const startAfterKey = params.startAfter + ? this._getHistorySnapshotKey(params.location, params.startAfter) + : undefined + try { + const ids: string[] = [] + let continuationToken: string | undefined + do { + const response = await this._s3.send( + new ListObjectsV2Command({ + Bucket: this._bucket, + Prefix: prefix, + StartAfter: continuationToken ? undefined : startAfterKey, + MaxKeys: params.limit !== undefined ? Math.min(S3_PAGE_SIZE, params.limit - ids.length) : S3_PAGE_SIZE, + ContinuationToken: continuationToken, + }) + ) + const page = (response.Contents ?? []) + .map((obj) => obj.Key?.match(SNAPSHOT_REGEX)?.[1]) + .filter((id): id is string => id !== undefined) + ids.push(...page) + if (response.IsTruncated) { + if (!response.NextContinuationToken) { + throw new SessionError('S3 returned truncated response without continuation token') + } + continuationToken = response.NextContinuationToken + } else { + continuationToken = undefined + } + } while (continuationToken && (params.limit === undefined || ids.length < params.limit)) + return params.limit !== undefined ? ids.slice(0, params.limit) : ids + } catch (error: unknown) { + if (error instanceof SessionError) throw error + if (this._isNotFoundError(error)) return [] + throw new SessionError(`Failed to list snapshots for session ${params.location.sessionId}`, { cause: error }) + } + } + + /** + * Deletes all S3 objects belonging to a session by listing and batch-deleting + * everything under `[/]/`. + * Handles buckets with more than 1000 objects via continuation token pagination. + * No-ops if the session has no objects. + */ + async deleteSession(params: { sessionId: string }): Promise { + const prefix = this._getSessionPrefix(params.sessionId) + try { + let continuationToken: string | undefined + do { + const response = await this._s3.send( + new ListObjectsV2Command({ Bucket: this._bucket, Prefix: prefix, ContinuationToken: continuationToken }) + ) + const keys = (response.Contents ?? []).map((obj) => ({ Key: obj.Key! })) + if (keys.length > 0) { + await this._s3.send(new DeleteObjectsCommand({ Bucket: this._bucket, Delete: { Objects: keys } })) + } + continuationToken = response.IsTruncated ? response.NextContinuationToken : undefined + } while (continuationToken) + } catch (error: unknown) { + throw new SessionError(`Failed to delete session ${params.sessionId}`, { cause: error }) + } + } + + /** + * Loads the snapshot manifest for a scope from S3. + * Returns a default manifest with the current timestamp if none exists yet. + */ + async loadManifest(params: { location: SnapshotLocation }): Promise { + const key = this._getKey(params.location, MANIFEST) + const manifest = await this._readJSON(key) + + return ( + manifest ?? { + schemaVersion: SCHEMA_VERSION, + updatedAt: new Date().toISOString(), + } + ) + } + + /** + * Persists the snapshot manifest for a scope to S3. + */ + async saveManifest(params: { location: SnapshotLocation; manifest: SnapshotManifest }): Promise { + const key = this._getKey(params.location, MANIFEST) + await this._writeJSON(key, params.manifest) + } + + /** + * Serializes data as JSON and writes it to S3 with `application/json` content type. + */ + private async _writeJSON(key: string, data: unknown): Promise { + try { + await this._s3.send( + new PutObjectCommand({ + Bucket: this._bucket, + Key: key, + Body: JSON.stringify(data, null, 2), + ContentType: 'application/json', + }) + ) + } catch (error) { + throw new SessionError(`Failed to write S3 object ${key}`, { cause: error }) + } + } + + /** + * Reads and parses a JSON object from S3. Returns null if the object does not exist. + * Throws SessionError on parse failure or unexpected S3 errors. + */ + private async _readJSON(key: string): Promise { + try { + const response = await this._s3.send(new GetObjectCommand({ Bucket: this._bucket, Key: key })) + const body = await response.Body?.transformToString() + if (!body) return null + return JSON.parse(body) + } catch (error: unknown) { + if (this._isNotFoundError(error)) { + return null + } + if (error instanceof SyntaxError) { + throw new SessionError(`Invalid JSON in S3 object ${key}`, { cause: error }) + } + throw new SessionError(`S3 error reading ${key}`, { cause: error }) + } + } + + /** Returns true if the error represents a missing S3 object (`NoSuchKey`) or bucket (`NoSuchBucket`). */ + private _isNotFoundError(error: unknown): error is { name: string } { + return ( + error !== null && + typeof error === 'object' && + 'name' in error && + typeof (error as { name: unknown }).name === 'string' && + ((error as { name: string }).name === 'NoSuchKey' || (error as { name: string }).name === 'NoSuchBucket') + ) + } + + /** Returns the S3 key for `snapshot_latest.json` within the given scope. */ + private _getLatestSnapshotKey(location: SnapshotLocation): string { + return this._getKey(location, SNAPSHOT_LATEST) + } + + /** Returns the S3 key for an immutable snapshot in `immutable_history/`. Validates the snapshotId before constructing the key. */ + private _getHistorySnapshotKey(location: SnapshotLocation, snapshotId: string): string { + validateIdentifier(snapshotId) + return this._getKey(location, `${IMMUTABLE_HISTORY}snapshot_${snapshotId}.json`) + } +} diff --git a/strands-ts/src/session/session-manager.ts b/strands-ts/src/session/session-manager.ts new file mode 100644 index 0000000000..b36d25d40a --- /dev/null +++ b/strands-ts/src/session/session-manager.ts @@ -0,0 +1,324 @@ +import type { SnapshotStorage, SnapshotLocation } from './storage.js' +import { validateIdentifier } from './validation.js' +import type { SnapshotTriggerCallback } from './types.js' +import type { Plugin } from '../plugins/plugin.js' +import type { LocalAgent } from '../types/agent.js' +import { AfterInvocationEvent, AfterModelCallEvent, InitializedEvent, MessageAddedEvent } from '../hooks/events.js' +import { v7 as uuidV7 } from 'uuid' +import { logger } from '../logging/logger.js' +import type { MultiAgentPlugin, MultiAgent } from '../multiagent/index.js' +import { MultiAgentState } from '../multiagent/state.js' +import { + takeSnapshot as takeMultiAgentSnapshot, + loadSnapshot as loadMultiAgentSnapshot, +} from '../multiagent/snapshot.js' +import { + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, +} from '../multiagent/events.js' +import type { Graph } from '../multiagent/graph.js' +import type { Swarm } from '../multiagent/swarm.js' + +/** + * Controls when `snapshot_latest` is saved automatically for agents. + * + * There are two kinds of snapshots: + * - **`snapshot_latest`**: A single mutable snapshot that is overwritten on each save. Used to + * resume the most recent conversation state (e.g. after a crash or restart). Always reflects + * the last saved point in time. + * - **Immutable snapshots**: Append-only snapshots with unique IDs (UUID v7), created only when + * `snapshotTrigger` fires. Used for checkpointing — you can restore to any prior state, not + * just the latest. + * + * `SaveLatestStrategy` controls how frequently `snapshot_latest` is updated: + * - `'invocation'`: after every agent invocation completes (default; balances durability and I/O) + * - `'message'`: after every message added (most durable, highest I/O) + * - `'trigger'`: only when a `snapshotTrigger` fires (or manually via `saveSnapshot`) + * + * Under `'invocation'` and `'message'`, guardrail redactions are persisted immediately so + * pre-redaction content never sits at rest. Under `'trigger'`, the caller's `snapshotTrigger` + * stays in control; redactions are only flushed if the trigger fires or `saveSnapshot` is called. + */ +export type SaveLatestStrategy = 'message' | 'invocation' | 'trigger' + +/** + * Controls when `snapshot_latest` is saved for multi-agent orchestrators. + * + * - `'node'`: after every node invocation completes (default; enables resume + * from the last completed node after a crash or restart) + * - `'invocation'`: after every orchestrator invocation completes (lower I/O, + * but only captures state at orchestrator invocation boundaries) + */ +export type MultiAgentSaveLatestStrategy = 'node' | 'invocation' + +export interface SessionManagerConfig { + /** Pluggable storage backends for snapshot persistence. Defaults to FileStorage in Node.js; required in browser environments. */ + storage: { + snapshot: SnapshotStorage + } + /** Unique session identifier. Defaults to `'default-session'`. */ + sessionId?: string + /** When to save snapshot_latest. Default: `'invocation'` (after each agent invocation completes). See {@link SaveLatestStrategy} for details. */ + saveLatestOn?: SaveLatestStrategy + /** Callback invoked after each invocation to decide whether to create an immutable snapshot. */ + snapshotTrigger?: SnapshotTriggerCallback + /** + * When to save snapshot_latest for multi-agent orchestrators. + * Default: `'node'` (after each node invocation completes). + * See {@link MultiAgentSaveLatestStrategy} for details. + */ + multiAgentSaveLatestOn?: MultiAgentSaveLatestStrategy +} + +/** + * Manages session persistence for agents, enabling conversation state + * to be saved and restored across invocations using pluggable storage backends. + * + * Also supports multi-agent orchestrators (Graph, Swarm) via the MultiAgentPlugin interface. + * Scope is auto-detected based on whether initAgent or initMultiAgent is called. + * + * @example + * ```typescript + * import { SessionManager, FileStorage } from '@strands-agents/sdk' + * + * const session = new SessionManager({ + * sessionId: 'my-session', + * storage: { snapshot: new FileStorage() }, + * }) + * const agent = new Agent({ sessionManager: session }) + * ``` + */ +export class SessionManager implements Plugin, MultiAgentPlugin { + private readonly _sessionId: string + private readonly _storage: { snapshot: SnapshotStorage } + private readonly _saveLatestOn: SaveLatestStrategy + private readonly _snapshotTrigger?: SnapshotTriggerCallback | undefined + private readonly _multiAgentSaveLatestOn: MultiAgentSaveLatestStrategy + private _multiAgentRestoredIds = new Set() + + /** + * Unique identifier for this plugin. + */ + get name(): string { + return 'strands:session-manager' + } + + constructor(config: SessionManagerConfig) { + this._sessionId = validateIdentifier(config.sessionId ?? 'default-session') + this._storage = { snapshot: config.storage.snapshot } + this._saveLatestOn = config.saveLatestOn ?? 'invocation' + this._multiAgentSaveLatestOn = config.multiAgentSaveLatestOn ?? 'node' + this._snapshotTrigger = config.snapshotTrigger + } + + /** Initializes the plugin by registering lifecycle hook callbacks. */ + public initAgent(agent: LocalAgent): void { + agent.addHook(InitializedEvent, async (event) => { + await this._onAgentInitialized(event) + }) + if (this._saveLatestOn === 'message') { + agent.addHook(MessageAddedEvent, async (event) => { + await this._onMessageAdded(event) + }) + } + + // Persist guardrail redactions immediately for auto-save strategies. + // 'trigger' is an explicit opt-out from auto-saves, so the caller's snapshotTrigger + // stays in control there. + if (this._saveLatestOn !== 'trigger') { + agent.addHook(AfterModelCallEvent, async (event) => { + await this._onAfterModelCall(event) + }) + } + agent.addHook(AfterInvocationEvent, async (event) => { + await this._onAfterAgentInvocation(event) + }) + } + + private _location(agent: LocalAgent): SnapshotLocation { + return { sessionId: this._sessionId, scope: 'agent', scopeId: agent.id } + } + + /** Saves a snapshot of the target's current state. */ + async saveSnapshot(params: { target: LocalAgent; isLatest: boolean }): Promise + async saveSnapshot(params: { target: Graph | Swarm; state?: MultiAgentState; isLatest: boolean }): Promise + async saveSnapshot(params: { + target: LocalAgent | Graph | Swarm + state?: MultiAgentState + isLatest: boolean + }): Promise { + const isAgent = 'messages' in params.target + const snapshot = isAgent + ? (params.target as LocalAgent).takeSnapshot({ preset: 'session' }) + : takeMultiAgentSnapshot(params.target as Graph | Swarm, params.state) + const snapshotId = params.isLatest ? 'latest' : uuidV7() + const location = isAgent + ? this._location(params.target as LocalAgent) + : this._multiAgentLocation(params.target as MultiAgent) + await this._storage.snapshot.saveSnapshot({ location, snapshotId, isLatest: params.isLatest, snapshot }) + } + + /** Deletes all snapshots and manifests for this session from storage. */ + async deleteSession(): Promise { + await this._storage.snapshot.deleteSession({ sessionId: this._sessionId }) + } + + /** Lists all available immutable snapshot IDs for the given agent target. */ + async listSnapshotIds(params: { target: LocalAgent; limit?: number; startAfter?: string }): Promise { + return this._storage.snapshot.listSnapshotIds({ + location: this._location(params.target), + ...(params.limit !== undefined && { limit: params.limit }), + ...(params.startAfter !== undefined && { startAfter: params.startAfter }), + }) + } + + /** Loads a snapshot from storage and restores it into the target. Returns false if no snapshot exists. */ + async restoreSnapshot(params: { target: LocalAgent; snapshotId?: string }): Promise + async restoreSnapshot(params: { + target: Graph | Swarm + state?: MultiAgentState + snapshotId?: string + }): Promise + async restoreSnapshot(params: { + target: LocalAgent | Graph | Swarm + state?: MultiAgentState + snapshotId?: string + }): Promise { + const isAgent = 'messages' in params.target + const location = isAgent + ? this._location(params.target as LocalAgent) + : this._multiAgentLocation(params.target as MultiAgent) + const snapshot = await this._storage.snapshot.loadSnapshot({ + location, + ...(params.snapshotId !== undefined && { snapshotId: params.snapshotId }), + }) + + if (!snapshot) return false + + if (isAgent) { + ;(params.target as LocalAgent).loadSnapshot(snapshot) + } else { + loadMultiAgentSnapshot(params.target as Graph | Swarm, snapshot, params.state) + } + return true + } + + /** Restores session state on agent initialization. */ + private async _onAgentInitialized(event: InitializedEvent): Promise { + const hadMessages = event.agent.messages.length > 0 + const restored = await this.restoreSnapshot({ target: event.agent }) + + if (restored && hadMessages) { + logger.warn( + `agent_id=<${event.agent.id}>, session_id=<${this._sessionId}> | agent had existing messages that were overwritten by session restore` + ) + } + + // Stateful models manage conversation history server-side, so any messages + // loaded from the snapshot would drift from the server's view on the next + // invocation. Duck-type the agent's `model` since `LocalAgent` does not + // expose it — `Agent` is the only implementor and always has one. + const statefulModel = (event.agent as { model?: { stateful?: boolean } }).model?.stateful + if (restored && statefulModel && event.agent.messages.length > 0) { + logger.debug( + `agent_id=<${event.agent.id}>, message_count=<${event.agent.messages.length}> | discarding restored messages for stateful model` + ) + event.agent.messages.length = 0 + } + } + + /** Saves latest on invocation and fires the snapshot trigger if configured. */ + private async _onAfterAgentInvocation(event: AfterInvocationEvent): Promise { + if (this._saveLatestOn === 'invocation') { + await this.saveSnapshot({ target: event.agent, isLatest: true }) + } + + if (this._snapshotTrigger?.({ agentData: event.agent })) { + await this._saveImmutableAndLatest(event.agent) + } + } + + private async _onMessageAdded(event: MessageAddedEvent): Promise { + await this.saveSnapshot({ target: event.agent, isLatest: true }) + } + + /** + * Saves snapshot when a message is redacted after a model call. + * Critical for ensuring guardrail redactions are persisted immediately. + */ + private async _onAfterModelCall(event: AfterModelCallEvent): Promise { + // Only save if there was a redaction + if (event.stopData?.redaction) { + await this.saveSnapshot({ target: event.agent, isLatest: true }) + } + } + + /** Captures one snapshot and writes it to both immutable history and snapshot_latest. */ + private async _saveImmutableAndLatest(agent: LocalAgent): Promise { + const snapshot = agent.takeSnapshot({ preset: 'session' }) + const snapshotId = uuidV7() + await Promise.all([ + this._storage.snapshot.saveSnapshot({ location: this._location(agent), snapshotId, isLatest: false, snapshot }), + this._storage.snapshot.saveSnapshot({ + location: this._location(agent), + snapshotId: 'latest', + isLatest: true, + snapshot, + }), + ]) + } + + // --------------------------------------------------------------------------- + // Multi-agent + // --------------------------------------------------------------------------- + + /** Initializes the multi-agent plugin by registering orchestrator lifecycle hooks. */ + public initMultiAgent(orchestrator: MultiAgent): void { + orchestrator.addHook(BeforeMultiAgentInvocationEvent, async (event) => { + await this._onBeforeMultiAgentInvocation(event) + }) + if (this._multiAgentSaveLatestOn === 'node') { + orchestrator.addHook(AfterNodeCallEvent, async (event) => { + await this._onAfterNodeCall(event) + }) + } + orchestrator.addHook(AfterMultiAgentInvocationEvent, async (event) => { + await this._onAfterMultiAgentInvocation(event) + }) + } + + private _multiAgentLocation(orchestrator: MultiAgent): SnapshotLocation { + return { sessionId: this._sessionId, scope: 'multiAgent', scopeId: orchestrator.id } + } + + /** Restores orchestrator state on first invocation (loads snapshot from storage once per orchestrator, then no-ops). */ + private async _onBeforeMultiAgentInvocation(event: BeforeMultiAgentInvocationEvent): Promise { + if (this._multiAgentRestoredIds.has(event.orchestrator.id)) return + this._multiAgentRestoredIds.add(event.orchestrator.id) + + const location = this._multiAgentLocation(event.orchestrator) + const snapshot = await this._storage.snapshot.loadSnapshot({ location }) + if (!snapshot) return + + loadMultiAgentSnapshot(event.orchestrator as Graph | Swarm, snapshot, event.state) + } + + /** Saves latest orchestrator snapshot after each node completes. */ + private async _onAfterNodeCall(event: AfterNodeCallEvent): Promise { + await this.saveSnapshot({ + target: event.orchestrator as Graph | Swarm, + state: event.state, + isLatest: true, + }) + } + + /** Saves latest orchestrator snapshot after invocation completes. */ + private async _onAfterMultiAgentInvocation(event: AfterMultiAgentInvocationEvent): Promise { + await this.saveSnapshot({ + target: event.orchestrator as Graph | Swarm, + state: event.state, + isLatest: true, + }) + } +} diff --git a/strands-ts/src/session/storage.ts b/strands-ts/src/session/storage.ts new file mode 100644 index 0000000000..e663f9343a --- /dev/null +++ b/strands-ts/src/session/storage.ts @@ -0,0 +1,93 @@ +import type { Scope, Snapshot, SnapshotManifest } from './types.js' + +/** + * Identifies the location of a snapshot within the storage hierarchy. + */ +export type SnapshotLocation = { + /** Session identifier */ + sessionId: string + /** Scope of the snapshot (agent or multi-agent) */ + scope: Scope + /** Scope-specific identifier (agent id or multi-agent id) */ + scopeId: string +} + +/** + * SessionStorage configuration for pluggable storage backends. + * Allows users to configure snapshot and transcript storage independently. + * + * @example + * ```typescript + * const storage: SessionStorage = { + * snapshot: new S3Storage({ bucket: 'my-bucket' }) + * } + * ``` + */ +export type SessionStorage = { + snapshot: SnapshotStorage + // TODO: Fast-follow - Transcript support +} + +/** + * Interface for snapshot persistence. + * Implementations provide storage backends (S3, filesystem, etc.). + * + * File layout convention: + * ``` + * sessions// + * scopes/ + * agent// + * snapshots/ + * snapshot_latest.json + * immutable_history/ + * snapshot_.json + * snapshot_.json + * ``` + */ +export interface SnapshotStorage { + /** + * Persists a snapshot to storage. + */ + saveSnapshot(params: { + location: SnapshotLocation + snapshotId: string + isLatest: boolean + snapshot: Snapshot + }): Promise + + /** + * Loads a snapshot from storage. + */ + loadSnapshot(params: { location: SnapshotLocation; snapshotId?: string }): Promise + + /** + * Lists all available immutable snapshot IDs for a session scope, sorted chronologically. + * Snapshot IDs are UUID v7 strings vended by the SDK — callers should treat them as opaque + * handles and never construct them manually. + * + * Typical pagination pattern: + * ```typescript + * const page1 = await storage.listSnapshotIds({ location }) + * const page2 = await storage.listSnapshotIds({ location, startAfter: page1.at(-1) }) + * ``` + * + * `limit` caps the number of returned IDs. `startAfter` is an exclusive cursor (the last ID + * from the previous page); it must be a UUID v7 obtained from a prior `listSnapshotIds` call. + */ + listSnapshotIds(params: { location: SnapshotLocation; limit?: number; startAfter?: string }): Promise + + /** + * Deletes all snapshots and directories belonging to the session ID. + */ + deleteSession(params: { sessionId: string }): Promise + + /** + * Loads the snapshot manifest. + */ + loadManifest(params: { location: SnapshotLocation }): Promise + + /** + * Saves the snapshot manifest. + */ + saveManifest(params: { location: SnapshotLocation; manifest: SnapshotManifest }): Promise +} diff --git a/strands-ts/src/session/types.ts b/strands-ts/src/session/types.ts new file mode 100644 index 0000000000..f9716b0445 --- /dev/null +++ b/strands-ts/src/session/types.ts @@ -0,0 +1,41 @@ +import type { LocalAgent } from '../types/agent.js' + +// Re-export Snapshot and Scope from the canonical location +export type { Snapshot, Scope } from '../types/snapshot.js' + +/** + * Manifest tracks snapshot metadata. + * Stored alongside snapshots to support versioning and future multi-agent patterns. + */ +export interface SnapshotManifest { + /** Schema version for forward/backward compatibility */ + schemaVersion: string + /** ISO 8601 timestamp of last manifest update */ + updatedAt: string +} + +/** + * Parameters passed to SnapshotTriggerCallback to determine when to create snapshots. + */ +export interface SnapshotTriggerParams { + /** Current agent data including messages and state */ + agentData: LocalAgent +} + +/** + * Callback function to determine when to create immutable snapshots. + * Called after each agent invocation to decide if a snapshot should be saved. + * + * @param params - Snapshot trigger parameters + * @returns true to create a snapshot, false to skip + * + * @example + * ```ts + * // Snapshot every 10 messages + * const trigger: SnapshotTriggerCallback = ({ agentData }) => agentData.messages.length % 10 === 0 + * + * // Snapshot when conversation exceeds 20 messages + * const trigger: SnapshotTriggerCallback = ({ agentData }) => agentData.messages.length > 20 + * ``` + */ +export type SnapshotTriggerCallback = (params: SnapshotTriggerParams) => boolean diff --git a/strands-ts/src/session/validation.ts b/strands-ts/src/session/validation.ts new file mode 100644 index 0000000000..a989f1389a --- /dev/null +++ b/strands-ts/src/session/validation.ts @@ -0,0 +1,28 @@ +/** + * Validates that an identifier contains only allowed characters. + * Allowed characters: lowercase letters (a-z), numbers (0-9), hyphens (-), and underscores (_) + * + * @param id - The identifier to validate + * @returns The validated identifier + * @throws Error if identifier contains invalid characters + */ +export function validateIdentifier(id: string): string { + const validPattern = /^[a-z0-9_-]+$/ + if (!validPattern.test(id)) { + throw new Error(`Identifier '${id}' can only contain lowercase letters, numbers, hyphens, and underscores`) + } + return id +} + +/** + * Validates that a string is a UUID v7. + * + * @param id - The string to validate + * @throws Error if the string is not a valid UUID v7 + */ +export function validateUuidV7(id: string): void { + const uuidV7Pattern = /^[0-9a-f]{8}-[0-9a-f]{4}-7[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i + if (!uuidV7Pattern.test(id)) { + throw new Error(`'${id}' is not a valid UUID v7 snapshot ID`) + } +} diff --git a/strands-ts/src/state-store.ts b/strands-ts/src/state-store.ts new file mode 100644 index 0000000000..773a402414 --- /dev/null +++ b/strands-ts/src/state-store.ts @@ -0,0 +1,163 @@ +import { deepCopy, deepCopyWithValidation, type JSONValue } from './types/json.js' +import { loadStateFromJSONSymbol, stateToJSONSymbol, type StateSerializable } from './types/serializable.js' + +/** + * Key-value storage for application state outside conversation context. + * State is not passed to the model during inference but is accessible + * by tools (via ToolContext) and application logic. + * + * All values are deep copied on get/set operations to prevent reference mutations. + * Values must be JSON serializable. + * + * @example + * ```typescript + * const state = new StateStore({ userId: 'user-123' }) + * state.set('sessionId', 'session-456') + * const userId = state.get('userId') // 'user-123' + * ``` + */ +export class StateStore implements StateSerializable { + private _state: Record + + /** + * Creates a new StateStore instance. + * + * @param initialState - Optional initial state values + * @throws Error if initialState is not JSON serializable + */ + constructor(initialState?: Record) { + if (initialState !== undefined) { + this._state = deepCopyWithValidation(initialState, 'initialState') as Record + } else { + this._state = {} + } + } + + /** + * Get a state value by key with optional type-safe property lookup. + * Returns a deep copy to prevent mutations. + * + * @typeParam TState - The complete state interface type + * @typeParam K - The property key (inferred from argument) + * @param key - Key to retrieve specific value + * @returns The value for the key, or undefined if key doesn't exist + * + * @example + * ```typescript + * // Typed usage + * const user = state.get('user') // { name: string; age: number } | undefined + * + * // Untyped usage + * const value = state.get('someKey') // JSONValue | undefined + * ``` + */ + get(key: K): TState[K] | undefined + get(key: string): JSONValue | undefined + get(key: string): JSONValue | Record | undefined { + if (key == null) { + throw new Error('key is required') + } + + const value = this._state[key] + if (value === undefined) { + return undefined + } + + // Return deep copy to prevent mutations + return deepCopy(value) + } + + /** + * Set a state value with optional type-safe property validation. + * Validates JSON serializability and stores a deep copy. + * + * @typeParam TState - The complete state interface type + * @typeParam K - The property key (inferred from argument) + * @param key - The key to set + * @param value - The value to store (must be JSON serializable) + * @throws Error if value is not JSON serializable + * + * @example + * ```typescript + * // Typed usage + * state.set('user', { name: 'Alice', age: 25 }) + * + * // Untyped usage + * state.set('someKey', { any: 'value' }) + * ``` + */ + set(key: K, value: TState[K]): void + set(key: string, value: unknown): void + set(key: string, value: unknown): void { + this._state[key] = deepCopyWithValidation(value, `value for key "${key}"`) + } + + /** + * Delete a state value by key with optional type-safe property validation. + * + * @typeParam TState - The complete state interface type + * @typeParam K - The property key (inferred from argument) + * @param key - The key to delete + * + * @example + * ```typescript + * // Typed usage + * state.delete('user') + * + * // Untyped usage + * state.delete('someKey') + * ``` + */ + delete(key: K): void + delete(key: string): void + delete(key: string): void { + delete this._state[key] + } + + /** + * Clear all state values. + */ + clear(): void { + this._state = {} + } + + /** + * Get a copy of all state as an object. + * + * @returns Deep copy of all state + */ + getAll(): Record { + return deepCopy(this._state) as Record + } + + /** + * Get all state keys. + * + * @returns Array of state keys + */ + keys(): string[] { + return Object.keys(this._state) + } + + /** + * Returns the serialized state as JSON value. + * + * @returns Deep copy of all state + */ + [stateToJSONSymbol](): JSONValue { + return deepCopy(this._state) as JSONValue + } + + /** + * Loads state from a previously serialized JSON value. + * + * @param json - The serialized state to load + */ + [loadStateFromJSONSymbol](json: JSONValue): void { + if (json !== null && typeof json === 'object' && !Array.isArray(json)) { + this._state = deepCopy(json) as Record + } else { + this._state = {} + } + } +} diff --git a/strands-ts/src/telemetry/__tests__/config.test.node.ts b/strands-ts/src/telemetry/__tests__/config.test.node.ts new file mode 100644 index 0000000000..4655533553 --- /dev/null +++ b/strands-ts/src/telemetry/__tests__/config.test.node.ts @@ -0,0 +1,229 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest' +import { NodeTracerProvider } from '@opentelemetry/sdk-trace-node' +import { ConsoleSpanExporter } from '@opentelemetry/sdk-trace-base' +import { OTLPTraceExporter } from '@opentelemetry/exporter-trace-otlp-http' +import { findMetricValue } from '../../__fixtures__/metrics-helpers.js' + +vi.mock('@opentelemetry/exporter-trace-otlp-http', () => ({ + OTLPTraceExporter: vi.fn(), +})) + +vi.mock('@opentelemetry/sdk-trace-base', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + ConsoleSpanExporter: vi.fn(), + } +}) + +// resetModules clears the module cache so each test gets a fresh singleton. +// Tests use dynamic await import() to re-import after the reset. + +describe('setupTracer (node-specific)', () => { + const originalEnv = { ...process.env } + + beforeEach(() => { + vi.resetModules() + vi.clearAllMocks() + }) + + afterEach(() => { + process.env = { ...originalEnv } + }) + + describe('provider auto-detection', () => { + it('should use NodeTracerProvider by default for async context support', async () => { + const telemetry = await import('../index.js') + + const provider = telemetry.setupTracer() + + expect(provider).toBeInstanceOf(NodeTracerProvider) + }) + + it('should accept a custom NodeTracerProvider', async () => { + const telemetry = await import('../index.js') + const customProvider = new NodeTracerProvider() + + const provider = telemetry.setupTracer({ provider: customProvider }) + + expect(provider).toBe(customProvider) + }) + }) + + describe('exporter configuration', () => { + it('should add OTLP exporter when exporters.otlp is true', async () => { + const telemetry = await import('../index.js') + + telemetry.setupTracer({ exporters: { otlp: true } }) + + expect(OTLPTraceExporter).toHaveBeenCalled() + }) + + it('should add console exporter when exporters.console is true', async () => { + const telemetry = await import('../index.js') + + telemetry.setupTracer({ exporters: { console: true } }) + + expect(ConsoleSpanExporter).toHaveBeenCalled() + }) + + it('should add both exporters when both are true', async () => { + const telemetry = await import('../index.js') + + telemetry.setupTracer({ exporters: { otlp: true, console: true } }) + + expect(OTLPTraceExporter).toHaveBeenCalled() + expect(ConsoleSpanExporter).toHaveBeenCalled() + }) + + it('should add no exporters when both are false', async () => { + const telemetry = await import('../index.js') + + telemetry.setupTracer({ exporters: { otlp: false, console: false } }) + + expect(OTLPTraceExporter).not.toHaveBeenCalled() + expect(ConsoleSpanExporter).not.toHaveBeenCalled() + }) + + it('should add no exporters when exporters config is empty', async () => { + const telemetry = await import('../index.js') + + telemetry.setupTracer({}) + + expect(OTLPTraceExporter).not.toHaveBeenCalled() + expect(ConsoleSpanExporter).not.toHaveBeenCalled() + }) + }) + + describe('resource attributes from environment', () => { + it('should use OTEL_SERVICE_NAME when set', async () => { + process.env.OTEL_SERVICE_NAME = 'my-custom-service' + const telemetry = await import('../index.js') + + const provider = telemetry.setupTracer() + + expect(provider['_resource'].attributes['service.name']).toBe('my-custom-service') + }) + + it('should use OTEL_SERVICE_NAMESPACE when set', async () => { + process.env.OTEL_SERVICE_NAMESPACE = 'my-namespace' + const telemetry = await import('../index.js') + + const provider = telemetry.setupTracer() + + expect(provider['_resource'].attributes['service.namespace']).toBe('my-namespace') + }) + + it('should use OTEL_DEPLOYMENT_ENVIRONMENT when set', async () => { + process.env.OTEL_DEPLOYMENT_ENVIRONMENT = 'production' + const telemetry = await import('../index.js') + + const provider = telemetry.setupTracer() + + expect(provider['_resource'].attributes['deployment.environment']).toBe('production') + }) + + it('should merge OTEL_RESOURCE_ATTRIBUTES with defaults', async () => { + process.env.OTEL_RESOURCE_ATTRIBUTES = 'service.version=1.0.0,custom.team=platform' + const telemetry = await import('../index.js') + + const provider = telemetry.setupTracer() + + expect(provider['_resource'].attributes['service.version']).toBe('1.0.0') + expect(provider['_resource'].attributes['custom.team']).toBe('platform') + expect(provider['_resource'].attributes['service.name']).toBe('strands-agents') + }) + + it('should allow OTEL_RESOURCE_ATTRIBUTES to override defaults', async () => { + process.env.OTEL_RESOURCE_ATTRIBUTES = 'service.name=custom-service,deployment.environment=production' + const telemetry = await import('../index.js') + + const provider = telemetry.setupTracer() + + expect(provider['_resource'].attributes['service.name']).toBe('custom-service') + expect(provider['_resource'].attributes['deployment.environment']).toBe('production') + }) + }) +}) + +describe('setupMeter (node-specific)', () => { + const originalEnv = { ...process.env } + + beforeEach(() => { + vi.resetModules() + vi.clearAllMocks() + }) + + afterEach(() => { + process.env = { ...originalEnv } + }) + + describe('resource attributes from environment', () => { + it('should use OTEL_SERVICE_NAME when set', async () => { + process.env.OTEL_SERVICE_NAME = 'my-meter-service' + const { MeterProvider, InMemoryMetricExporter, PeriodicExportingMetricReader, AggregationTemporality } = + await import('@opentelemetry/sdk-metrics') + const { resourceFromAttributes } = await import('@opentelemetry/resources') + const telemetry = await import('../index.js') + + const exporter = new InMemoryMetricExporter(AggregationTemporality.CUMULATIVE) + const reader = new PeriodicExportingMetricReader({ exporter, exportIntervalMillis: 100 }) + const customProvider = new MeterProvider({ + resource: resourceFromAttributes({ 'service.name': 'my-meter-service' }), + readers: [reader], + }) + const provider = telemetry.setupMeter({ provider: customProvider }) + + provider.getMeter('test').createCounter('probe').add(1) + await provider.forceFlush() + + const resource = exporter.getMetrics().at(-1)?.resource + expect(resource?.attributes['service.name']).toBe('my-meter-service') + + await provider.shutdown() + }) + }) + + describe('global meter provider registration', () => { + it('returns a provider that produces real metrics via its own meter', async () => { + const { + MeterProvider: SdkMeterProvider, + InMemoryMetricExporter, + PeriodicExportingMetricReader, + AggregationTemporality, + } = await import('@opentelemetry/sdk-metrics') + const telemetry = await import('../index.js') + + const testExporter = new InMemoryMetricExporter(AggregationTemporality.CUMULATIVE) + const testReader = new PeriodicExportingMetricReader({ + exporter: testExporter, + exportIntervalMillis: 100, + }) + const testProvider = new SdkMeterProvider({ readers: [testReader] }) + + const provider = telemetry.setupMeter({ provider: testProvider }) + + const meter = provider.getMeter('test-registration') + const counter = meter.createCounter('test_registration_counter') + counter.add(42) + + await testProvider.forceFlush() + + expect(findMetricValue(testExporter.getMetrics(), 'test_registration_counter')).toBe(42) + + await testProvider.shutdown() + }) + }) + + describe('custom provider', () => { + it('accepts a custom MeterProvider', async () => { + const { MeterProvider } = await import('@opentelemetry/sdk-metrics') + const telemetry = await import('../index.js') + const customProvider = new MeterProvider() + + const provider = telemetry.setupMeter({ provider: customProvider }) + + expect(provider).toBe(customProvider) + }) + }) +}) diff --git a/strands-ts/src/telemetry/__tests__/config.test.ts b/strands-ts/src/telemetry/__tests__/config.test.ts new file mode 100644 index 0000000000..81295123c7 --- /dev/null +++ b/strands-ts/src/telemetry/__tests__/config.test.ts @@ -0,0 +1,84 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest' + +// resetModules clears the module cache so each test gets a fresh singleton. +// Tests use dynamic await import() to re-import after the reset. + +describe('setupTracer', () => { + beforeEach(() => { + vi.resetModules() + vi.clearAllMocks() + }) + + describe('singleton behavior', () => { + it('should return the same provider instance when called twice', async () => { + const telemetry = await import('../index.js') + + const provider1 = telemetry.setupTracer({ exporters: { console: true } }) + const provider2 = telemetry.setupTracer({ exporters: { otlp: true } }) + + expect(provider1).toBe(provider2) + }) + + it('should log a warning when called twice', async () => { + const { logger } = await import('../../logging/index.js') + const warnSpy = vi.spyOn(logger, 'warn') + const telemetry = await import('../index.js') + + telemetry.setupTracer() + telemetry.setupTracer() + + expect(warnSpy).toHaveBeenCalledWith('tracer provider already initialized, returning existing provider') + }) + }) + + describe('resource attributes', () => { + it('should use strands-agents as default service name', async () => { + const telemetry = await import('../index.js') + + const provider = telemetry.setupTracer() + + expect(provider['_resource'].attributes['service.name']).toBe('strands-agents') + }) + + it('should include default resource attributes', async () => { + const telemetry = await import('../index.js') + + const provider = telemetry.setupTracer() + + expect(provider['_resource'].attributes['service.name']).toBe('strands-agents') + expect(provider['_resource'].attributes['service.namespace']).toBe('strands') + expect(provider['_resource'].attributes['deployment.environment']).toBe('development') + expect(provider['_resource'].attributes['telemetry.sdk.name']).toBe('opentelemetry') + expect(provider['_resource'].attributes['telemetry.sdk.language']).toBe('typescript') + }) + }) +}) + +describe('setupMeter', () => { + beforeEach(() => { + vi.resetModules() + vi.clearAllMocks() + }) + + describe('singleton behavior', () => { + it('returns the same provider instance when called twice', async () => { + const telemetry = await import('../index.js') + + const provider1 = telemetry.setupMeter({ exporters: { console: true } }) + const provider2 = telemetry.setupMeter({ exporters: { otlp: true } }) + + expect(provider1).toBe(provider2) + }) + + it('logs a warning when called twice', async () => { + const { logger } = await import('../../logging/index.js') + const warnSpy = vi.spyOn(logger, 'warn') + const telemetry = await import('../index.js') + + telemetry.setupMeter() + telemetry.setupMeter() + + expect(warnSpy).toHaveBeenCalledWith('meter provider already initialized, returning existing provider') + }) + }) +}) diff --git a/strands-ts/src/telemetry/__tests__/json.test.ts b/strands-ts/src/telemetry/__tests__/json.test.ts new file mode 100644 index 0000000000..e6f40f435a --- /dev/null +++ b/strands-ts/src/telemetry/__tests__/json.test.ts @@ -0,0 +1,105 @@ +import { describe, it, expect } from 'vitest' +import { jsonReplacer } from '../json.js' + +describe('jsonReplacer', () => { + describe('primitive values', () => { + it('serializes strings', () => { + expect(JSON.stringify('hello', jsonReplacer)).toBe('"hello"') + }) + + it('serializes numbers', () => { + expect(JSON.stringify(42, jsonReplacer)).toBe('42') + }) + + it('serializes booleans', () => { + expect(JSON.stringify(true, jsonReplacer)).toBe('true') + }) + + it('serializes null', () => { + expect(JSON.stringify(null, jsonReplacer)).toBe('null') + }) + }) + + describe('object values', () => { + it('serializes simple objects', () => { + const obj = { key: 'value', number: 42, bool: true } + expect(JSON.stringify(obj, jsonReplacer)).toBe(JSON.stringify(obj)) + }) + + it('serializes arrays', () => { + const arr = [1, 2, 3, 'test'] + expect(JSON.stringify(arr, jsonReplacer)).toBe(JSON.stringify(arr)) + }) + }) + + describe('special types', () => { + it('handles Date objects', () => { + const date = new Date('2024-01-01T00:00:00.000Z') + expect(JSON.stringify(date, jsonReplacer)).toBe('"2024-01-01T00:00:00.000Z"') + }) + + it('handles Date objects nested in objects', () => { + const date = new Date('2024-01-01T00:00:00.000Z') + expect(JSON.stringify({ timestamp: date, name: 'test' }, jsonReplacer)).toBe( + '{"timestamp":"2024-01-01T00:00:00.000Z","name":"test"}' + ) + }) + + it('replaces BigInt values', () => { + const bigint = BigInt(12345678901234567890n) + expect(JSON.stringify(bigint, jsonReplacer)).toBe('""') + }) + + it('replaces functions', () => { + const fn = (): string => 'test' + const result = JSON.parse(JSON.stringify({ callback: fn, name: 'test' }, jsonReplacer)) + expect(result).toStrictEqual({ callback: '', name: 'test' }) + }) + + it('replaces symbols', () => { + const result = JSON.parse(JSON.stringify({ sym: Symbol('test'), name: 'test' }, jsonReplacer)) + expect(result).toStrictEqual({ sym: '', name: 'test' }) + }) + + it('replaces ArrayBuffer values', () => { + const buffer = new ArrayBuffer(8) + const result = JSON.parse(JSON.stringify({ data: buffer, name: 'test' }, jsonReplacer)) + expect(result).toStrictEqual({ data: '', name: 'test' }) + }) + + it('replaces Uint8Array values', () => { + const bytes = new Uint8Array([1, 2, 3]) + const result = JSON.parse(JSON.stringify({ data: bytes, name: 'test' }, jsonReplacer)) + expect(result).toStrictEqual({ data: '', name: 'test' }) + }) + + it('handles mixed content in arrays', () => { + const fn = (): string => 'test' + const data = ['value', 42, fn, null, { key: true }] + const result = JSON.parse(JSON.stringify(data, jsonReplacer)) + expect(result).toStrictEqual(['value', 42, '', null, { key: true }]) + }) + + it('handles mixed content in nested objects', () => { + const fn = (): string => 'test' + const now = new Date('2025-01-01T12:00:00.000Z') + const data = { + metadata: { timestamp: now, version: '1.0', debug: { obj: fn } }, + content: [ + { type: 'text', value: 'Hello' }, + { type: 'binary', value: fn }, + ], + list: [fn, 1234, true, null, 'string'], + } + const result = JSON.parse(JSON.stringify(data, jsonReplacer)) + expect(result).toStrictEqual({ + metadata: { timestamp: '2025-01-01T12:00:00.000Z', version: '1.0', debug: { obj: '' } }, + content: [ + { type: 'text', value: 'Hello' }, + { type: 'binary', value: '' }, + ], + list: ['', 1234, true, null, 'string'], + }) + }) + }) +}) diff --git a/strands-ts/src/telemetry/__tests__/local-trace.test.ts b/strands-ts/src/telemetry/__tests__/local-trace.test.ts new file mode 100644 index 0000000000..ef00f0df66 --- /dev/null +++ b/strands-ts/src/telemetry/__tests__/local-trace.test.ts @@ -0,0 +1,197 @@ +import { describe, it, expect } from 'vitest' +import { AgentTrace } from '../tracer.js' +import { Message, TextBlock } from '../../types/messages.js' + +describe('LocalTrace', () => { + describe('constructor', () => { + it('generates a unique id in UUID format', () => { + const trace1 = new AgentTrace('test') + const trace2 = new AgentTrace('test') + + const uuidRegex = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/ + expect(trace1.id).toMatch(uuidRegex) + expect(trace2.id).toMatch(uuidRegex) + expect(trace1.id).not.toBe(trace2.id) + }) + + it('sets name and defaults', () => { + const trace = new AgentTrace('Cycle 1') + + expect(trace.name).toBe('Cycle 1') + expect(trace.parentId).toBeNull() + expect(trace.endTime).toBeNull() + expect(trace.duration).toBe(0) + expect(trace.children).toEqual([]) + expect(trace.metadata).toEqual({}) + expect(trace.message).toBeNull() + }) + + it('uses current time as default startTime', () => { + const before = Date.now() + const trace = new AgentTrace('test') + const after = Date.now() + + expect(trace.startTime).toBeGreaterThanOrEqual(before) + expect(trace.startTime).toBeLessThanOrEqual(after) + }) + + it('accepts a custom startTime', () => { + const trace = new AgentTrace('test', { startTime: 1000 }) + + expect(trace.startTime).toBe(1000) + }) + }) + + describe('parent-child relationships', () => { + it('adds child to parent when parent is provided', () => { + const parent = new AgentTrace('parent') + const child = new AgentTrace('child', { parent }) + + expect(parent.children).toHaveLength(1) + expect(parent.children[0]).toBe(child) + expect(child.parentId).toBe(parent.id) + }) + + it('supports multiple children', () => { + const parent = new AgentTrace('parent') + const child1 = new AgentTrace('child1', { parent }) + const child2 = new AgentTrace('child2', { parent }) + + expect(parent.children).toHaveLength(2) + expect(parent.children[0]).toBe(child1) + expect(parent.children[1]).toBe(child2) + }) + + it('sets parentId to null when no parent is provided', () => { + const trace = new AgentTrace('root') + + expect(trace.parentId).toBeNull() + }) + + it('builds a three-level hierarchy', () => { + const root = new AgentTrace('Cycle 1', { startTime: 1000 }) + const model = new AgentTrace('stream_messages', { parent: root, startTime: 1001 }) + const tool = new AgentTrace('Tool: calc', { parent: root, startTime: 1050 }) + + expect(root.children).toHaveLength(2) + expect(root.children[0]!.name).toBe('stream_messages') + expect(root.children[0]!.parentId).toBe(root.id) + expect(root.children[1]!.name).toBe('Tool: calc') + expect(root.children[1]!.parentId).toBe(root.id) + expect(model.parentId).toBe(root.id) + expect(tool.parentId).toBe(root.id) + }) + }) + + describe('end', () => { + it('computes duration from startTime to endTime', () => { + const trace = new AgentTrace('test', { startTime: 1000 }) + + trace.end(1500) + + expect(trace.endTime).toBe(1500) + expect(trace.duration).toBe(500) + }) + + it('uses current time when no endTime is provided', () => { + const before = Date.now() + const trace = new AgentTrace('test') + + trace.end() + + expect(trace.endTime).toBeGreaterThanOrEqual(before) + expect(trace.duration).toBeGreaterThanOrEqual(0) + }) + }) + + describe('metadata and message', () => { + it('stores cycle metadata', () => { + const trace = new AgentTrace('Cycle 1') + + trace.metadata.cycleId = 'cycle-1' + + expect(trace.metadata).toStrictEqual({ cycleId: 'cycle-1' }) + }) + + it('stores tool metadata', () => { + const trace = new AgentTrace('Tool: calc') + + trace.metadata.toolUseId = 'tool-1' + trace.metadata.toolName = 'calc' + + expect(trace.metadata).toStrictEqual({ toolUseId: 'tool-1', toolName: 'calc' }) + }) + + it('stores a message with role and content', () => { + const msg = new Message({ role: 'assistant', content: [new TextBlock('hello')] }) + const trace = new AgentTrace('stream_messages') + + trace.message = msg + + expect(trace.message.role).toBe('assistant') + expect(trace.message.content).toStrictEqual([new TextBlock('hello')]) + }) + }) + + describe('toJSON', () => { + it('returns complete data for a default trace', () => { + const trace = new AgentTrace('Cycle 1', { startTime: 1000 }) + + const json = trace.toJSON() + + expect(json).toStrictEqual({ + id: trace.id, + name: 'Cycle 1', + parentId: null, + startTime: 1000, + endTime: null, + duration: 0, + children: [], + metadata: {}, + message: null, + }) + }) + + it('serializes a hierarchy with children and metadata', () => { + const root = new AgentTrace('Cycle 1', { startTime: 1000 }) + root.metadata.cycleId = 'cycle-1' + const child = new AgentTrace('stream_messages', { parent: root, startTime: 1001 }) + child.end(1100) + root.end(1200) + + const json = root.toJSON() + + expect(json.name).toBe('Cycle 1') + expect(json.metadata.cycleId).toBe('cycle-1') + expect(json.duration).toBe(200) + expect(json.children).toHaveLength(1) + expect(json.children[0]).toStrictEqual({ + id: child.id, + name: 'stream_messages', + parentId: root.id, + startTime: 1001, + endTime: 1100, + duration: 99, + children: [], + metadata: {}, + message: null, + }) + }) + + it('serializes tool metadata correctly', () => { + const toolTrace = new AgentTrace('Tool: calc', { startTime: 1000 }) + toolTrace.metadata.toolUseId = 'tool-123' + toolTrace.metadata.toolName = 'calc' + toolTrace.end(1500) + + const json = toolTrace.toJSON() + + expect(json.name).toBe('Tool: calc') + expect(json.metadata).toStrictEqual({ + toolUseId: 'tool-123', + toolName: 'calc', + }) + expect(json.duration).toBe(500) + }) + }) +}) diff --git a/strands-ts/src/telemetry/__tests__/meter.test.ts b/strands-ts/src/telemetry/__tests__/meter.test.ts new file mode 100644 index 0000000000..a4824f9a0a --- /dev/null +++ b/strands-ts/src/telemetry/__tests__/meter.test.ts @@ -0,0 +1,832 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { metrics as otelMetrics, type Meter as OtelMeter } from '@opentelemetry/api' +import { Meter, AgentMetrics } from '../meter.js' +import { MockMeter } from '../../__fixtures__/mock-meter.js' +import type { ToolUse } from '../../tools/types.js' + +describe('Meter', () => { + const makeTool = (name: string, toolUseId: string): ToolUse => ({ + name, + toolUseId, + input: {}, + }) + + let meter: Meter + + beforeEach(() => { + meter = new Meter() + }) + + describe('metrics getter', () => { + it('returns an AgentMetrics instance', () => { + expect(meter.metrics).toBeInstanceOf(AgentMetrics) + }) + + it('returns zeroed snapshot for fresh instance', () => { + const snapshot = meter.metrics + expect(snapshot).toStrictEqual( + new AgentMetrics({ + cycleCount: 0, + toolMetrics: {}, + agentInvocations: [], + accumulatedUsage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + accumulatedMetrics: { latencyMs: 0 }, + }) + ) + }) + + it('returns complete snapshot after a realistic agent execution', () => { + vi.useFakeTimers() + vi.setSystemTime(100_000) + + meter.startNewInvocation() + + const c1 = meter.startCycle() + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + metrics: { latencyMs: 100 }, + }) + meter.endToolCall({ + tool: makeTool('search', 'tid-1'), + duration: 0.5, + success: true, + }) + vi.setSystemTime(103_000) + meter.endCycle(c1.startTime) + + vi.setSystemTime(200_000) + const c2 = meter.startCycle() + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 }, + metrics: { latencyMs: 250 }, + }) + meter.endToolCall({ + tool: makeTool('search', 'tid-2'), + duration: 1.5, + success: false, + }) + vi.setSystemTime(205_000) + meter.endCycle(c2.startTime) + + const snapshot = meter.metrics + + expect(snapshot.cycleCount).toBe(2) + expect(snapshot.totalDuration).toBe(8000) + expect(snapshot.accumulatedUsage).toStrictEqual({ inputTokens: 30, outputTokens: 15, totalTokens: 45 }) + expect(snapshot.accumulatedMetrics).toStrictEqual({ latencyMs: 350 }) + expect(snapshot.toolMetrics).toStrictEqual({ + search: { + callCount: 2, + successCount: 1, + errorCount: 1, + totalTime: 2.0, + }, + }) + expect(snapshot.agentInvocations).toStrictEqual([ + { + usage: { inputTokens: 30, outputTokens: 15, totalTokens: 45 }, + cycles: [ + { cycleId: 'cycle-1', duration: 3000, usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } }, + { cycleId: 'cycle-2', duration: 5000, usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 } }, + ], + }, + ]) + + vi.useRealTimers() + }) + + it('tracks multiple invocations independently', () => { + meter.startNewInvocation() + meter.startCycle() + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }) + + meter.startNewInvocation() + meter.startCycle() + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 }, + }) + + expect(meter.metrics.agentInvocations).toStrictEqual([ + { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + cycles: [{ cycleId: 'cycle-1', duration: 0, usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } }], + }, + { + usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 }, + cycles: [{ cycleId: 'cycle-2', duration: 0, usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 } }], + }, + ]) + }) + }) + + describe('startNewInvocation', () => { + it('appends an invocation with empty cycles and zeroed usage', () => { + meter.startNewInvocation() + + expect(meter.metrics.agentInvocations).toStrictEqual([ + { cycles: [], usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 } }, + ]) + }) + + it('latestAgentInvocation returns the most recently added invocation', () => { + meter.startNewInvocation() + meter.startNewInvocation() + + const snapshot = meter.metrics + expect(snapshot.agentInvocations).toHaveLength(2) + expect(snapshot.latestAgentInvocation).toBe(snapshot.agentInvocations[1]) + }) + + it('evicts oldest entries when exceeding max history cap', () => { + for (let i = 0; i < 60; i++) { + meter.startNewInvocation() + meter.startCycle() + } + + const invocations = meter.metrics.agentInvocations + expect(invocations).toHaveLength(50) + // First retained entry is the 11th invocation (oldest 10 were evicted) + expect(invocations[0]!.cycles[0]!.cycleId).toBe('cycle-11') + expect(invocations[49]!.cycles[0]!.cycleId).toBe('cycle-60') + }) + }) + + describe('startCycle', () => { + it('returns cycle id and start time', () => { + vi.useFakeTimers() + vi.setSystemTime(100_000) + + const result = meter.startCycle() + + expect(result).toStrictEqual({ + cycleId: 'cycle-1', + startTime: 100_000, + }) + expect(meter.metrics.cycleCount).toBe(1) + vi.useRealTimers() + }) + + it('adds cycle entry to the latest invocation', () => { + meter.startNewInvocation() + meter.startCycle() + + expect(meter.metrics.latestAgentInvocation!.cycles).toStrictEqual([ + { cycleId: 'cycle-1', duration: 0, usage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 } }, + ]) + }) + + it('does not fail when no invocation exists', () => { + const result = meter.startCycle() + + expect(result.cycleId).toBe('cycle-1') + expect(meter.metrics.agentInvocations).toStrictEqual([]) + }) + }) + + describe('endCycle', () => { + it('records duration on the latest cycle', () => { + vi.useFakeTimers() + vi.setSystemTime(200_000) + + meter.startNewInvocation() + meter.startCycle() + meter.endCycle(100_000) + + expect(meter.metrics.latestAgentInvocation!.cycles[0]!.duration).toBe(100_000) + vi.useRealTimers() + }) + + it('does not fail when no invocation exists', () => { + vi.useFakeTimers() + vi.setSystemTime(200_000) + + meter.startCycle() + + expect(() => meter.endCycle(100_000)).not.toThrow() + expect(meter.metrics.agentInvocations).toStrictEqual([]) + vi.useRealTimers() + }) + + it('does not fail when invocation has no cycles', () => { + vi.useFakeTimers() + vi.setSystemTime(200_000) + + meter.startNewInvocation() + + expect(() => meter.endCycle(100_000)).not.toThrow() + expect(meter.metrics.latestAgentInvocation!.cycles).toStrictEqual([]) + vi.useRealTimers() + }) + }) + + describe('endToolCall', () => { + it('records success', () => { + meter.endToolCall({ + tool: makeTool('myTool', 'id-1'), + duration: 1.5, + success: true, + }) + + expect(meter.metrics.toolMetrics).toStrictEqual({ + myTool: { callCount: 1, successCount: 1, errorCount: 0, totalTime: 1.5 }, + }) + }) + + it('records failure', () => { + meter.endToolCall({ + tool: makeTool('myTool', 'id-1'), + duration: 0.5, + success: false, + }) + + expect(meter.metrics.toolMetrics).toStrictEqual({ + myTool: { callCount: 1, successCount: 0, errorCount: 1, totalTime: 0.5 }, + }) + }) + + it('accumulates across multiple calls to the same tool', () => { + meter.endToolCall({ + tool: makeTool('myTool', 'id-1'), + duration: 1.0, + success: true, + }) + meter.endToolCall({ + tool: makeTool('myTool', 'id-2'), + duration: 2.0, + success: false, + }) + + expect(meter.metrics.toolMetrics).toStrictEqual({ + myTool: { callCount: 2, successCount: 1, errorCount: 1, totalTime: 3.0 }, + }) + }) + + it('tracks different tools independently', () => { + meter.endToolCall({ + tool: makeTool('toolA', 'id-1'), + duration: 1.0, + success: true, + }) + meter.endToolCall({ + tool: makeTool('toolB', 'id-2'), + duration: 2.0, + success: false, + }) + + expect(meter.metrics.toolMetrics).toStrictEqual({ + toolA: { callCount: 1, successCount: 1, errorCount: 0, totalTime: 1.0 }, + toolB: { callCount: 1, successCount: 0, errorCount: 1, totalTime: 2.0 }, + }) + }) + }) + + describe('latestContextSize', () => { + it('is undefined when no invocations have occurred', () => { + expect(meter.metrics.latestContextSize).toBeUndefined() + }) + + it('returns the most recent inputTokens after a model call', () => { + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + }) + + expect(meter.metrics.latestContextSize).toBe(100) + }) + + it('updates across multiple cycles', () => { + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + }) + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 200, outputTokens: 80, totalTokens: 280 }, + }) + + expect(meter.metrics.latestContextSize).toBe(200) + }) + + it('updates across multiple invocations', () => { + meter.startNewInvocation() + meter.startCycle() + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + }) + + meter.startNewInvocation() + meter.startCycle() + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 300, outputTokens: 100, totalTokens: 400 }, + }) + + expect(meter.metrics.latestContextSize).toBe(300) + }) + + it('remains undefined when metadata has no usage', () => { + meter.updateCycle({ + type: 'modelMetadataEvent', + metrics: { latencyMs: 100 }, + }) + + expect(meter.metrics.latestContextSize).toBeUndefined() + }) + + it('remains undefined when updateCycle is called with undefined', () => { + meter.updateCycle(undefined) + + expect(meter.metrics.latestContextSize).toBeUndefined() + }) + }) + + describe('projectedContextSize', () => { + it('is undefined when no invocations have occurred', () => { + expect(meter.metrics.projectedContextSize).toBeUndefined() + }) + + it('returns inputTokens + outputTokens after a model call', () => { + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + }) + + expect(meter.metrics.projectedContextSize).toBe(150) + }) + + it('updates across multiple cycles', () => { + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + }) + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 200, outputTokens: 80, totalTokens: 280 }, + }) + + expect(meter.metrics.projectedContextSize).toBe(280) + }) + }) + + describe('updateCycle', () => { + it('accumulates usage and latency from metadata', () => { + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 5, outputTokens: 3, totalTokens: 8 }, + metrics: { latencyMs: 100 }, + }) + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 7, totalTokens: 17 }, + metrics: { latencyMs: 200 }, + }) + + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 15, + outputTokens: 10, + totalTokens: 25, + }) + expect(meter.metrics.accumulatedMetrics).toStrictEqual({ latencyMs: 300 }) + }) + + it('accumulates cache tokens across calls', () => { + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { + inputTokens: 10, + outputTokens: 5, + totalTokens: 15, + cacheReadInputTokens: 3, + cacheWriteInputTokens: 2, + }, + }) + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { + inputTokens: 5, + outputTokens: 2, + totalTokens: 7, + cacheReadInputTokens: 4, + }, + }) + + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 15, + outputTokens: 7, + totalTokens: 22, + cacheReadInputTokens: 7, + cacheWriteInputTokens: 2, + }) + }) + + it('handles usage-only metadata', () => { + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }) + + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 10, + outputTokens: 5, + totalTokens: 15, + }) + expect(meter.metrics.accumulatedMetrics).toStrictEqual({ latencyMs: 0 }) + }) + + it('handles metrics-only metadata', () => { + meter.updateCycle({ + type: 'modelMetadataEvent', + metrics: { latencyMs: 250 }, + }) + + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + }) + expect(meter.metrics.accumulatedMetrics).toStrictEqual({ latencyMs: 250 }) + }) + + it('propagates usage to invocation and current cycle', () => { + meter.startNewInvocation() + meter.startCycle() + + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }) + + const invocation = meter.metrics.latestAgentInvocation! + expect(invocation).toStrictEqual({ + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + cycles: [{ cycleId: 'cycle-1', duration: 0, usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } }], + }) + }) + + it('is a no-op when metadata is undefined', () => { + meter.updateCycle(undefined) + + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + }) + expect(meter.metrics.accumulatedMetrics).toStrictEqual({ latencyMs: 0 }) + }) + + it('is a no-op when called with no arguments', () => { + meter.updateCycle() + + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + }) + expect(meter.metrics.accumulatedMetrics).toStrictEqual({ latencyMs: 0 }) + }) + + it('is a no-op when metadata has neither usage nor metrics', () => { + meter.updateCycle({ type: 'modelMetadataEvent' }) + + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + }) + expect(meter.metrics.accumulatedMetrics).toStrictEqual({ latencyMs: 0 }) + }) + + it('does not fail when no invocation exists', () => { + expect(() => { + meter.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + }) + }).not.toThrow() + + expect(meter.metrics.accumulatedUsage).toStrictEqual({ + inputTokens: 10, + outputTokens: 5, + totalTokens: 15, + }) + }) + }) + + describe('OTEL instrument emission', () => { + let mockMeter: MockMeter + + beforeEach(() => { + mockMeter = new MockMeter() + vi.spyOn(otelMetrics, 'getMeter').mockReturnValue(mockMeter as unknown as OtelMeter) + }) + + it('emits invocation counter on startNewInvocation', () => { + const m = new Meter() + + m.startNewInvocation() + + expect(mockMeter.getCounter('gen_ai.agent.invocation.count')?.sum).toBe(1) + }) + + it('emits cycle counter on startCycle', () => { + const m = new Meter() + + m.startCycle() + + expect(mockMeter.getCounter('gen_ai.agent.cycle.count')?.sum).toBe(1) + }) + + it('emits cycle duration histogram on endCycle', () => { + vi.useFakeTimers() + vi.setSystemTime(5000) + const m = new Meter() + + m.endCycle(3000) + + expect(mockMeter.getHistogram('gen_ai.agent.cycle.duration')?.sum).toBe(2000) + vi.useRealTimers() + }) + + it('emits tool call counter and duration on successful endToolCall', () => { + const m = new Meter() + + m.endToolCall({ tool: makeTool('search', 'id-1'), duration: 150, success: true }) + + expect(mockMeter.getCounter('gen_ai.agent.tool.call.count')?.dataPoints).toStrictEqual([ + { value: 1, attributes: { 'gen_ai.tool.name': 'search' } }, + ]) + expect(mockMeter.getHistogram('gen_ai.agent.tool.duration')?.dataPoints).toStrictEqual([ + { value: 150, attributes: { 'gen_ai.tool.name': 'search' } }, + ]) + expect(mockMeter.getCounter('gen_ai.agent.tool.error.count')?.dataPoints).toStrictEqual([]) + }) + + it('emits tool call counter, error counter, and duration on failed endToolCall', () => { + const m = new Meter() + + m.endToolCall({ tool: makeTool('search', 'id-1'), duration: 50, success: false }) + + expect(mockMeter.getCounter('gen_ai.agent.tool.call.count')?.dataPoints).toStrictEqual([ + { value: 1, attributes: { 'gen_ai.tool.name': 'search' } }, + ]) + expect(mockMeter.getCounter('gen_ai.agent.tool.error.count')?.dataPoints).toStrictEqual([ + { value: 1, attributes: { 'gen_ai.tool.name': 'search' } }, + ]) + expect(mockMeter.getHistogram('gen_ai.agent.tool.duration')?.dataPoints).toStrictEqual([ + { value: 50, attributes: { 'gen_ai.tool.name': 'search' } }, + ]) + }) + + it('emits input token counter, output token counter, and model latency on updateCycle', () => { + const m = new Meter() + + m.updateCycle({ + type: 'modelMetadataEvent', + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + metrics: { latencyMs: 200 }, + }) + + expect(mockMeter.getCounter('gen_ai.agent.tokens.input')?.sum).toBe(100) + expect(mockMeter.getCounter('gen_ai.agent.tokens.output')?.sum).toBe(50) + expect(mockMeter.getHistogram('gen_ai.agent.model.latency')?.sum).toBe(200) + }) + + it('does not emit token counters or latency when updateCycle has no usage or metrics', () => { + const m = new Meter() + + m.updateCycle({ type: 'modelMetadataEvent' }) + + expect(mockMeter.getCounter('gen_ai.agent.tokens.input')?.dataPoints).toStrictEqual([]) + expect(mockMeter.getCounter('gen_ai.agent.tokens.output')?.dataPoints).toStrictEqual([]) + expect(mockMeter.getHistogram('gen_ai.agent.model.latency')?.dataPoints).toStrictEqual([]) + }) + + it('does not emit any OTEL instruments when updateCycle is called with undefined', () => { + const m = new Meter() + + m.updateCycle(undefined) + + expect(mockMeter.getCounter('gen_ai.agent.tokens.input')?.dataPoints).toStrictEqual([]) + expect(mockMeter.getCounter('gen_ai.agent.tokens.output')?.dataPoints).toStrictEqual([]) + expect(mockMeter.getHistogram('gen_ai.agent.model.latency')?.dataPoints).toStrictEqual([]) + }) + + it('emits time-to-first-token histogram in seconds when timeToFirstByteMs is provided', () => { + const m = new Meter() + + m.updateCycle({ + type: 'modelMetadataEvent', + metrics: { latencyMs: 500, timeToFirstByteMs: 150 }, + }) + + expect(mockMeter.getHistogram('gen_ai.server.time_to_first_token')?.sum).toBeCloseTo(0.15) + }) + + it('does not emit time-to-first-token histogram when timeToFirstByteMs is undefined', () => { + const m = new Meter() + + m.updateCycle({ + type: 'modelMetadataEvent', + metrics: { latencyMs: 500 }, + }) + + expect(mockMeter.getHistogram('gen_ai.server.time_to_first_token')?.dataPoints).toStrictEqual([]) + }) + + it('does not emit time-to-first-token histogram when timeToFirstByteMs is zero', () => { + const m = new Meter() + + m.updateCycle({ + type: 'modelMetadataEvent', + metrics: { latencyMs: 500, timeToFirstByteMs: 0 }, + }) + + expect(mockMeter.getHistogram('gen_ai.server.time_to_first_token')?.dataPoints).toStrictEqual([]) + }) + }) +}) + +describe('AgentMetrics', () => { + describe('toJSON', () => { + it('returns complete zeroed data for default instance', () => { + const metrics = new AgentMetrics() + expect(metrics.toJSON()).toStrictEqual({ + cycleCount: 0, + accumulatedUsage: { inputTokens: 0, outputTokens: 0, totalTokens: 0 }, + accumulatedMetrics: { latencyMs: 0 }, + agentInvocations: [], + toolMetrics: {}, + totalDuration: 0, + }) + }) + + it('returns data from provided metrics', () => { + const metrics = new AgentMetrics({ + cycleCount: 2, + totalDuration: 8000, + toolMetrics: { + search: { callCount: 2, successCount: 1, errorCount: 1, totalTime: 2.0 }, + }, + accumulatedUsage: { inputTokens: 30, outputTokens: 15, totalTokens: 45 }, + accumulatedMetrics: { latencyMs: 350 }, + agentInvocations: [ + { + usage: { inputTokens: 30, outputTokens: 15, totalTokens: 45 }, + cycles: [ + { cycleId: 'cycle-1', duration: 3000, usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } }, + { cycleId: 'cycle-2', duration: 5000, usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 } }, + ], + }, + ], + }) + + expect(metrics.toJSON()).toStrictEqual({ + cycleCount: 2, + accumulatedUsage: { inputTokens: 30, outputTokens: 15, totalTokens: 45 }, + accumulatedMetrics: { latencyMs: 350 }, + agentInvocations: [ + { + usage: { inputTokens: 30, outputTokens: 15, totalTokens: 45 }, + cycles: [ + { cycleId: 'cycle-1', duration: 3000, usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } }, + { cycleId: 'cycle-2', duration: 5000, usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 } }, + ], + }, + ], + toolMetrics: { + search: { callCount: 2, successCount: 1, errorCount: 1, totalTime: 2.0 }, + }, + totalDuration: 8000, + }) + }) + }) + + describe('toJSON with latestContextSize', () => { + it('includes latestContextSize when set', () => { + const metrics = new AgentMetrics({ latestContextSize: 42 }) + expect(metrics.toJSON()).toHaveProperty('latestContextSize', 42) + }) + + it('omits latestContextSize when undefined', () => { + const metrics = new AgentMetrics() + expect(metrics.toJSON()).not.toHaveProperty('latestContextSize') + }) + }) + + describe('toJSON roundtrip', () => { + it('reconstructs equivalent AgentMetrics from serialized data', () => { + const original = new AgentMetrics({ + cycleCount: 3, + accumulatedUsage: { inputTokens: 50, outputTokens: 25, totalTokens: 75 }, + accumulatedMetrics: { latencyMs: 500 }, + agentInvocations: [ + { + usage: { inputTokens: 50, outputTokens: 25, totalTokens: 75 }, + cycles: [ + { cycleId: 'cycle-1', duration: 1000, usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } }, + { cycleId: 'cycle-2', duration: 2000, usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 } }, + { cycleId: 'cycle-3', duration: 3000, usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 } }, + ], + }, + ], + toolMetrics: { + search: { callCount: 2, successCount: 2, errorCount: 0, totalTime: 1.5 }, + calc: { callCount: 1, successCount: 0, errorCount: 1, totalTime: 0.3 }, + }, + }) + + const json = JSON.stringify(original) + const restored = new AgentMetrics(JSON.parse(json)) + + expect(restored.toJSON()).toStrictEqual(original.toJSON()) + }) + }) + + describe('computed getters', () => { + it('latestAgentInvocation returns the last invocation', () => { + const metrics = new AgentMetrics({ + agentInvocations: [ + { cycles: [], usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 } }, + { cycles: [], usage: { inputTokens: 20, outputTokens: 10, totalTokens: 30 } }, + ], + }) + + expect(metrics.latestAgentInvocation).toBe(metrics.agentInvocations[1]) + }) + + it('latestAgentInvocation returns undefined when empty', () => { + const metrics = new AgentMetrics() + expect(metrics.latestAgentInvocation).toBeUndefined() + }) + + it('accumulatedData returns usage and metrics together', () => { + const metrics = new AgentMetrics({ + accumulatedUsage: { inputTokens: 30, outputTokens: 15, totalTokens: 45 }, + accumulatedMetrics: { latencyMs: 350 }, + }) + + expect(metrics.accumulatedData).toStrictEqual({ + usage: { inputTokens: 30, outputTokens: 15, totalTokens: 45 }, + metrics: { latencyMs: 350 }, + }) + }) + + it('averageCycleTime computes average', () => { + const metrics = new AgentMetrics({ + cycleCount: 2, + totalDuration: 8000, + }) + expect(metrics.averageCycleTime).toBe(4000) + }) + + it('averageCycleTime returns 0 when no cycles', () => { + const metrics = new AgentMetrics() + expect(metrics.averageCycleTime).toBe(0) + }) + + it('toolUsage adds computed averageTime and successRate', () => { + const metrics = new AgentMetrics({ + toolMetrics: { + search: { callCount: 2, successCount: 1, errorCount: 1, totalTime: 2.0 }, + }, + }) + + expect(metrics.toolUsage).toStrictEqual({ + search: { + callCount: 2, + successCount: 1, + errorCount: 1, + totalTime: 2.0, + averageTime: 1.0, + successRate: 0.5, + }, + }) + }) + + it('toolUsage returns 0 for averageTime and successRate when callCount is 0', () => { + const metrics = new AgentMetrics({ + toolMetrics: { + broken: { callCount: 0, successCount: 0, errorCount: 0, totalTime: 0 }, + }, + }) + + expect(metrics.toolUsage).toStrictEqual({ + broken: { + callCount: 0, + successCount: 0, + errorCount: 0, + totalTime: 0, + averageTime: 0, + successRate: 0, + }, + }) + }) + }) +}) diff --git a/strands-ts/src/telemetry/__tests__/tracer.test.node.ts b/strands-ts/src/telemetry/__tests__/tracer.test.node.ts new file mode 100644 index 0000000000..4eeff70b1f --- /dev/null +++ b/strands-ts/src/telemetry/__tests__/tracer.test.node.ts @@ -0,0 +1,962 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest' +import type { Span, SpanAttributeValue } from '@opentelemetry/api' +import { SpanStatusCode, trace, context } from '@opentelemetry/api' +import { Tracer } from '../tracer.js' +import { Message, TextBlock, ToolResultBlock, ToolUseBlock, CachePointBlock } from '../../types/messages.js' +import { MockSpan, eventAttr } from '../../__fixtures__/mock-span.js' +import { textMessage } from '../../__fixtures__/agent-helpers.js' + +// Partial mock: keep real SpanStatusCode etc., replace context and trace +vi.mock('@opentelemetry/api', async (importOriginal) => ({ + ...(await importOriginal()), + context: { active: vi.fn(() => ({})), with: vi.fn((_ctx: unknown, fn: () => unknown) => fn()) }, + trace: { + getTracer: vi.fn(), + setSpan: vi.fn(), + }, +})) + +describe('Tracer', () => { + let mockSpan: MockSpan + let mockStartSpan: ReturnType Span>> + + beforeEach(() => { + mockSpan = new MockSpan() + mockStartSpan = vi.fn<(name: string, ...args: unknown[]) => Span>().mockReturnValue(mockSpan) + + vi.mocked(trace.getTracer).mockReturnValue({ + startSpan: mockStartSpan, + startActiveSpan: vi.fn(), + }) + + // Default to stable conventions; tests needing latest override this + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', '') + }) + + /** Get the [spanName, options] from the startSpan call for the span under test. */ + function getStartSpanCall(): [string, { attributes: Record }] { + return mockStartSpan.mock.calls[0] as [string, { attributes: Record }] + } + + describe('constructor', () => { + it('reads service name from OTEL_SERVICE_NAME env var', () => { + vi.stubEnv('OTEL_SERVICE_NAME', 'my-custom-service') + + new Tracer() + + expect(trace.getTracer).toHaveBeenCalledWith('my-custom-service') + }) + + it('defaults service name to strands-agents', () => { + vi.stubEnv('OTEL_SERVICE_NAME', '') + + new Tracer() + + expect(trace.getTracer).toHaveBeenCalledWith('strands-agents') + }) + }) + + describe('startAgentSpan', () => { + it('creates span with correct name and standard attributes', () => { + const tracer = new Tracer() + + tracer.startAgentSpan({ + messages: [textMessage('user', 'Hello')], + agentName: 'test-agent', + modelId: 'model-123', + }) + + const [spanName, options] = getStartSpanCall() + expect(spanName).toBe('invoke_agent test-agent') + expect(options.attributes).toMatchObject({ + 'gen_ai.operation.name': 'invoke_agent', + 'gen_ai.system': expect.any(String), + 'gen_ai.agent.name': 'test-agent', + 'gen_ai.request.model': 'model-123', + name: 'invoke_agent test-agent', + }) + }) + + it('includes agent id when provided', () => { + const tracer = new Tracer() + + tracer.startAgentSpan({ + messages: [textMessage('user', 'Hello')], + agentName: 'test-agent', + agentId: 'agent-42', + }) + + const [, options] = getStartSpanCall() + expect(options.attributes['gen_ai.agent.id']).toBe('agent-42') + }) + + it('serializes tool names into gen_ai.agent.tools', () => { + const tracer = new Tracer() + + tracer.startAgentSpan({ + messages: [textMessage('user', 'Hello')], + agentName: 'test-agent', + tools: [{ name: 'calculator' }, { name: 'search' }], + }) + + const [, options] = getStartSpanCall() + expect(options.attributes['gen_ai.agent.tools']).toBe('["calculator","search"]') + }) + + it('includes tool definitions when gen_ai_tool_definitions opt-in is set', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_tool_definitions') + const tracer = new Tracer() + const toolsConfig = { calc: { name: 'calc', description: 'Calculator' } } + + tracer.startAgentSpan({ + messages: [textMessage('user', 'Hello')], + agentName: 'test-agent', + toolsConfig, + }) + + const [, options] = getStartSpanCall() + expect(options.attributes['gen_ai.tool.definitions']).toBe(JSON.stringify(toolsConfig)) + }) + + it('serializes system prompt into attribute', () => { + const tracer = new Tracer() + + tracer.startAgentSpan({ + messages: [textMessage('user', 'Hello')], + agentName: 'test-agent', + systemPrompt: 'You are a helpful assistant', + }) + + const [, options] = getStartSpanCall() + expect(options.attributes['system_prompt']).toBe('"You are a helpful assistant"') + }) + + it('merges constructor-level and call-level trace attributes', () => { + const tracer = new Tracer({ 'global.attr': 'global-val' }) + + tracer.startAgentSpan({ + messages: [textMessage('user', 'Hello')], + agentName: 'test-agent', + traceAttributes: { 'custom.session': 'sess-1' }, + }) + + const [, options] = getStartSpanCall() + expect(options.attributes['global.attr']).toBe('global-val') + expect(options.attributes['custom.session']).toBe('sess-1') + }) + + it('adds separate stable message events per message', () => { + const tracer = new Tracer() + + tracer.startAgentSpan({ + messages: [textMessage('user', 'Hello'), textMessage('assistant', 'Hi')], + agentName: 'test-agent', + }) + + expect(mockSpan.getEvents('gen_ai.user.message')).toHaveLength(1) + expect(mockSpan.getEvents('gen_ai.assistant.message')).toHaveLength(1) + }) + + it('classifies tool result messages as gen_ai.tool.message', () => { + const tracer = new Tracer() + + const toolResultMsg = new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId: 'tool-1', status: 'success', content: [new TextBlock('done')] })], + }) + + tracer.startAgentSpan({ messages: [toolResultMsg], agentName: 'test-agent' }) + + expect(mockSpan.getEvents('gen_ai.tool.message')).toHaveLength(1) + }) + + it('adds single operation details event with latest conventions', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + + tracer.startAgentSpan({ + messages: [textMessage('user', 'Hello'), textMessage('assistant', 'Hi')], + agentName: 'test-agent', + }) + + const detailEvents = mockSpan.getEvents('gen_ai.client.inference.operation.details') + expect(detailEvents).toHaveLength(1) + + const inputMessages = JSON.parse(eventAttr(detailEvents[0]!, 'gen_ai.input.messages')) + expect(inputMessages).toStrictEqual([ + { role: 'user', parts: [{ type: 'text', content: 'Hello' }] }, + { role: 'assistant', parts: [{ type: 'text', content: 'Hi' }] }, + ]) + }) + + it('uses gen_ai.provider.name with latest conventions', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + + tracer.startAgentSpan({ messages: [textMessage('user', 'Hello')], agentName: 'test-agent' }) + + const [, options] = getStartSpanCall() + expect(options.attributes['gen_ai.provider.name']).toBeDefined() + expect(options.attributes['gen_ai.system']).toBeUndefined() + }) + + it('uses gen_ai.system with stable conventions', () => { + const tracer = new Tracer() + + tracer.startAgentSpan({ messages: [textMessage('user', 'Hello')], agentName: 'test-agent' }) + + const [, options] = getStartSpanCall() + expect(options.attributes['gen_ai.system']).toBeDefined() + expect(options.attributes['gen_ai.provider.name']).toBeUndefined() + }) + }) + + describe('endAgentSpan', () => { + it('sets OK status and ends span on success', () => { + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + tracer.endAgentSpan(span) + + expect(mockSpan.calls.setStatus).toContainEqual({ status: { code: SpanStatusCode.OK } }) + expect(mockSpan.calls.end).toHaveLength(1) + }) + + it('sets ERROR status and records exception on error', () => { + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + const error = new Error('agent failed') + + tracer.endAgentSpan(span, { error }) + + expect(mockSpan.calls.setStatus).toContainEqual({ + status: { code: SpanStatusCode.ERROR, message: 'agent failed' }, + }) + expect(mockSpan.calls.recordException).toContainEqual({ exception: error, time: undefined }) + }) + + it('sets accumulated usage attributes', () => { + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + tracer.endAgentSpan(span, { + accumulatedUsage: { inputTokens: 100, outputTokens: 200, totalTokens: 300 }, + }) + + expect(mockSpan.getAttributeValue('gen_ai.usage.input_tokens')).toBe(100) + expect(mockSpan.getAttributeValue('gen_ai.usage.output_tokens')).toBe(200) + expect(mockSpan.getAttributeValue('gen_ai.usage.total_tokens')).toBe(300) + expect(mockSpan.getAttributeValue('gen_ai.usage.prompt_tokens')).toBe(100) + expect(mockSpan.getAttributeValue('gen_ai.usage.completion_tokens')).toBe(200) + }) + + it('adds response event with stable conventions', () => { + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + const response = new Message({ role: 'assistant', content: [new TextBlock('Hello back')] }) + tracer.endAgentSpan(span, { response, stopReason: 'end_turn' }) + + const choiceEvents = mockSpan.getEvents('gen_ai.choice') + expect(choiceEvents).toHaveLength(1) + expect(eventAttr(choiceEvents[0]!, 'message')).toBe('Hello back') + expect(eventAttr(choiceEvents[0]!, 'finish_reason')).toBe('end_turn') + }) + + it('adds response event with latest conventions', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + const response = new Message({ role: 'assistant', content: [new TextBlock('Hello back')] }) + tracer.endAgentSpan(span, { response, stopReason: 'end_turn' }) + + const detailEvents = mockSpan.getEvents('gen_ai.client.inference.operation.details') + const outputEvent = detailEvents.find((e) => eventAttr(e, 'gen_ai.output.messages')) + expect(outputEvent).toBeDefined() + const parsed = JSON.parse(eventAttr(outputEvent!, 'gen_ai.output.messages')) + expect(parsed).toStrictEqual([ + { role: 'assistant', parts: [{ type: 'text', content: 'Hello back' }], finish_reason: 'end_turn' }, + ]) + }) + + it('handles null span gracefully', () => { + const tracer = new Tracer() + + expect(() => tracer.endAgentSpan(null)).not.toThrow() + expect(mockSpan.calls.end).toHaveLength(0) + }) + }) + + describe('startModelInvokeSpan', () => { + it('creates span with chat operation name and model id', () => { + const tracer = new Tracer() + + tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hello')], modelId: 'claude-3' }) + + const [spanName, options] = getStartSpanCall() + expect(spanName).toBe('chat') + expect(options.attributes).toMatchObject({ + 'gen_ai.operation.name': 'chat', + 'gen_ai.request.model': 'claude-3', + }) + }) + + it('adds message events to span', () => { + const tracer = new Tracer() + + tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hello')] }) + + expect(mockSpan.getEvents('gen_ai.user.message')).toHaveLength(1) + }) + }) + + describe('endModelInvokeSpan', () => { + it('sets usage and metrics attributes', () => { + const tracer = new Tracer() + const span = tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hi')], modelId: 'model-1' }) + + tracer.endModelInvokeSpan(span, { + usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, + metrics: { latencyMs: 500 }, + }) + + expect(mockSpan.getAttributeValue('gen_ai.usage.input_tokens')).toBe(10) + expect(mockSpan.getAttributeValue('gen_ai.usage.output_tokens')).toBe(20) + expect(mockSpan.getAttributeValue('gen_ai.usage.total_tokens')).toBe(30) + expect(mockSpan.getAttributeValue('gen_ai.server.request.duration')).toBe(500) + }) + + it('sets cache token attributes when provided', () => { + const tracer = new Tracer() + const span = tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hi')] }) + + tracer.endModelInvokeSpan(span, { + usage: { + inputTokens: 100, + outputTokens: 200, + totalTokens: 300, + cacheReadInputTokens: 50, + cacheWriteInputTokens: 25, + }, + }) + + expect(mockSpan.getAttributeValue('gen_ai.usage.cache_read_input_tokens')).toBe(50) + expect(mockSpan.getAttributeValue('gen_ai.usage.cache_write_input_tokens')).toBe(25) + }) + + it('skips cache token attributes when zero', () => { + const tracer = new Tracer() + const span = tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hi')] }) + + tracer.endModelInvokeSpan(span, { + usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cacheReadInputTokens: 0 }, + }) + + expect(mockSpan.getAttributeValue('gen_ai.usage.cache_read_input_tokens')).toBeUndefined() + }) + + it('skips latency attribute when zero', () => { + const tracer = new Tracer() + const span = tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hi')] }) + + tracer.endModelInvokeSpan(span, { + usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, + metrics: { latencyMs: 0 }, + }) + + expect(mockSpan.getAttributeValue('gen_ai.server.request.duration')).toBeUndefined() + }) + + it('adds output event with stable conventions for mixed content', () => { + const tracer = new Tracer() + const span = tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hi')] }) + + const output = new Message({ + role: 'assistant', + content: [ + new TextBlock('The answer is 42'), + new ToolUseBlock({ name: 'calc', toolUseId: 'tool-1', input: { expr: '6*7' } }), + ], + }) + + tracer.endModelInvokeSpan(span, { output, stopReason: 'tool_use' }) + + const choiceEvents = mockSpan.getEvents('gen_ai.choice') + expect(choiceEvents).toHaveLength(1) + expect(eventAttr(choiceEvents[0]!, 'finish_reason')).toBe('tool_use') + + const parsed = JSON.parse(eventAttr(choiceEvents[0]!, 'message')) + expect(parsed).toStrictEqual([ + { text: 'The answer is 42' }, + { type: 'toolUse', name: 'calc', toolUseId: 'tool-1', input: { expr: '6*7' } }, + ]) + }) + + it('adds output event with latest conventions for mixed content', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + const span = tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hi')] }) + + const output = new Message({ + role: 'assistant', + content: [ + new TextBlock('The answer'), + new ToolUseBlock({ name: 'calc', toolUseId: 'tool-1', input: { x: 1 } }), + ], + }) + + tracer.endModelInvokeSpan(span, { output, stopReason: 'tool_use' }) + + const detailEvents = mockSpan.getEvents('gen_ai.client.inference.operation.details') + const outputEvent = detailEvents.find((e) => eventAttr(e, 'gen_ai.output.messages')) + expect(outputEvent).toBeDefined() + const parsed = JSON.parse(eventAttr(outputEvent!, 'gen_ai.output.messages')) + expect(parsed).toStrictEqual([ + { + role: 'assistant', + parts: [ + { type: 'text', content: 'The answer' }, + { type: 'tool_call', name: 'calc', id: 'tool-1', arguments: { x: 1 } }, + ], + finish_reason: 'tool_use', + }, + ]) + }) + + it('records error on model invocation failure', () => { + const tracer = new Tracer() + const span = tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hi')] }) + const error = new Error('model timeout') + + tracer.endModelInvokeSpan(span, { error }) + + expect(mockSpan.calls.setStatus).toContainEqual({ + status: { code: SpanStatusCode.ERROR, message: 'model timeout' }, + }) + expect(mockSpan.calls.recordException).toContainEqual({ exception: error, time: undefined }) + }) + + it('handles null span gracefully', () => { + const tracer = new Tracer() + + expect(() => tracer.endModelInvokeSpan(null)).not.toThrow() + }) + }) + + describe('startToolCallSpan', () => { + it('creates span with tool name and call id', () => { + const tracer = new Tracer() + + tracer.startToolCallSpan({ + tool: { name: 'calculator', toolUseId: 'call-1', input: { expr: '2+2' } }, + }) + + const [spanName, options] = getStartSpanCall() + expect(spanName).toBe('execute_tool calculator') + expect(options.attributes).toMatchObject({ + 'gen_ai.operation.name': 'execute_tool', + 'gen_ai.tool.name': 'calculator', + 'gen_ai.tool.call.id': 'call-1', + }) + }) + + it('adds stable tool message event with serialized input', () => { + const tracer = new Tracer() + + tracer.startToolCallSpan({ + tool: { name: 'search', toolUseId: 'call-2', input: { query: 'test' } }, + }) + + const toolEvents = mockSpan.getEvents('gen_ai.tool.message') + expect(toolEvents).toHaveLength(1) + expect(eventAttr(toolEvents[0]!, 'role')).toBe('tool') + expect(eventAttr(toolEvents[0]!, 'content')).toBe('{"query":"test"}') + expect(eventAttr(toolEvents[0]!, 'id')).toBe('call-2') + }) + + it('adds latest convention tool input event', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + + tracer.startToolCallSpan({ + tool: { name: 'search', toolUseId: 'call-2', input: { query: 'test' } }, + }) + + const detailEvents = mockSpan.getEvents('gen_ai.client.inference.operation.details') + expect(detailEvents).toHaveLength(1) + const parsed = JSON.parse(eventAttr(detailEvents[0]!, 'gen_ai.input.messages')) + expect(parsed).toStrictEqual([ + { + role: 'tool', + parts: [{ type: 'tool_call', name: 'search', id: 'call-2', arguments: { query: 'test' } }], + }, + ]) + }) + }) + + describe('endToolCallSpan', () => { + it('sets tool status attribute and adds stable result event', () => { + const tracer = new Tracer() + const span = tracer.startToolCallSpan({ + tool: { name: 'calc', toolUseId: 'call-1', input: {} }, + }) + + const toolResult = new ToolResultBlock({ + toolUseId: 'call-1', + status: 'success', + content: [new TextBlock('42')], + }) + + tracer.endToolCallSpan(span, { toolResult }) + + expect(mockSpan.getAttributeValue('gen_ai.tool.status')).toBe('success') + + const choiceEvents = mockSpan.getEvents('gen_ai.choice') + expect(choiceEvents).toHaveLength(1) + expect(eventAttr(choiceEvents[0]!, 'id')).toBe('call-1') + }) + + it('adds latest convention tool result event', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + const span = tracer.startToolCallSpan({ + tool: { name: 'calc', toolUseId: 'call-1', input: {} }, + }) + + const toolResult = new ToolResultBlock({ + toolUseId: 'call-1', + status: 'success', + content: [new TextBlock('42')], + }) + + tracer.endToolCallSpan(span, { toolResult }) + + const detailEvents = mockSpan.getEvents('gen_ai.client.inference.operation.details') + const outputEvent = detailEvents.find((e) => eventAttr(e, 'gen_ai.output.messages')) + expect(outputEvent).toBeDefined() + const parsed = JSON.parse(eventAttr(outputEvent!, 'gen_ai.output.messages')) + expect(parsed[0].role).toBe('tool') + expect(parsed[0].parts[0].type).toBe('tool_call_response') + expect(parsed[0].parts[0].id).toBe('call-1') + }) + + it('records error on tool failure', () => { + const tracer = new Tracer() + const span = tracer.startToolCallSpan({ + tool: { name: 'calc', toolUseId: 'call-1', input: {} }, + }) + const error = new Error('tool crashed') + + tracer.endToolCallSpan(span, { error }) + + expect(mockSpan.calls.setStatus).toContainEqual({ + status: { code: SpanStatusCode.ERROR, message: 'tool crashed' }, + }) + expect(mockSpan.calls.recordException).toContainEqual({ exception: error, time: undefined }) + }) + + it('handles null span gracefully', () => { + const tracer = new Tracer() + + expect(() => tracer.endToolCallSpan(null)).not.toThrow() + }) + }) + + describe('startAgentLoopSpan', () => { + it('creates span with cycle id attribute', () => { + const tracer = new Tracer() + + tracer.startAgentLoopSpan({ cycleId: 'cycle-42', messages: [textMessage('user', 'Hi')] }) + + const [spanName, options] = getStartSpanCall() + expect(spanName).toBe('execute_agent_loop_cycle') + expect(options.attributes['agent_loop.cycle_id']).toBe('cycle-42') + }) + + it('adds message events to loop span', () => { + const tracer = new Tracer() + + tracer.startAgentLoopSpan({ cycleId: 'cycle-1', messages: [textMessage('user', 'Hello')] }) + + expect(mockSpan.getEvents('gen_ai.user.message')).toHaveLength(1) + }) + + it('creates local trace with cycleId in metadata', () => { + const tracer = new Tracer() + + tracer.startAgentLoopSpan({ cycleId: 'cycle-123', messages: [] }) + + const traces = tracer.localTraces + expect(traces).toEqual([ + expect.objectContaining({ + name: 'Cycle 1', + metadata: expect.objectContaining({ cycleId: 'cycle-123' }), + }), + ]) + }) + + it('stores unique cycleIds for multiple cycles', () => { + const tracer = new Tracer() + + tracer.startAgentLoopSpan({ cycleId: 'cycle-abc', messages: [] }) + tracer.endAgentLoopSpan(mockSpan) + tracer.startAgentLoopSpan({ cycleId: 'cycle-xyz', messages: [] }) + + const traces = tracer.localTraces + expect(traces).toEqual([ + expect.objectContaining({ + name: 'Cycle 1', + metadata: expect.objectContaining({ cycleId: 'cycle-abc' }), + }), + expect.objectContaining({ + name: 'Cycle 2', + metadata: expect.objectContaining({ cycleId: 'cycle-xyz' }), + }), + ]) + }) + }) + + describe('endAgentLoopSpan', () => { + it('ends span with OK status', () => { + const tracer = new Tracer() + const span = tracer.startAgentLoopSpan({ cycleId: 'cycle-1', messages: [textMessage('user', 'Hi')] }) + + tracer.endAgentLoopSpan(span) + + expect(mockSpan.calls.setStatus).toContainEqual({ status: { code: SpanStatusCode.OK } }) + expect(mockSpan.calls.end).toHaveLength(1) + }) + + it('records error on loop failure', () => { + const tracer = new Tracer() + const span = tracer.startAgentLoopSpan({ cycleId: 'cycle-1', messages: [textMessage('user', 'Hi')] }) + const error = new Error('loop failed') + + tracer.endAgentLoopSpan(span, { error }) + + expect(mockSpan.calls.setStatus).toContainEqual({ + status: { code: SpanStatusCode.ERROR, message: 'loop failed' }, + }) + expect(mockSpan.calls.recordException).toContainEqual({ exception: error, time: undefined }) + }) + + it('handles null span gracefully', () => { + const tracer = new Tracer() + + expect(() => tracer.endAgentLoopSpan(null)).not.toThrow() + }) + }) + + describe('withSpanContext', () => { + it('executes callback directly when span is null', () => { + const tracer = new Tracer() + const fn = vi.fn(() => 'result') + + const result = tracer.withSpanContext(null, fn) + + expect(result).toBe('result') + expect(fn).toHaveBeenCalledOnce() + expect(context.with).not.toHaveBeenCalled() + }) + + it('executes callback within span context when span is provided', () => { + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + const mockContext = { spanContext: true } + vi.mocked(trace.setSpan).mockReturnValue(mockContext as never) + + tracer.withSpanContext(span, () => 'inside') + + expect(trace.setSpan).toHaveBeenCalledWith({}, span) + expect(context.with).toHaveBeenCalledWith(mockContext, expect.any(Function)) + }) + + it('propagates return value from callback', () => { + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + const result = tracer.withSpanContext(span, () => 42) + + expect(result).toBe(42) + }) + }) + + describe('message event formatting', () => { + it('maps tool use blocks to tool_call parts in latest conventions', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + + const messages = [ + new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'search', toolUseId: 'tu-1', input: { q: 'test' } })], + }), + ] + + tracer.startAgentSpan({ messages, agentName: 'agent' }) + + const detailEvents = mockSpan.getEvents('gen_ai.client.inference.operation.details') + const parsed = JSON.parse(eventAttr(detailEvents[0]!, 'gen_ai.input.messages')) + expect(parsed[0].parts[0]).toStrictEqual({ + type: 'tool_call', + name: 'search', + id: 'tu-1', + arguments: { q: 'test' }, + }) + }) + + it('maps tool result blocks to tool_call_response parts in latest conventions', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + + const messages = [ + new Message({ + role: 'user', + content: [new ToolResultBlock({ toolUseId: 'tu-1', status: 'success', content: [new TextBlock('result')] })], + }), + ] + + tracer.startAgentSpan({ messages, agentName: 'agent' }) + + const detailEvents = mockSpan.getEvents('gen_ai.client.inference.operation.details') + const parsed = JSON.parse(eventAttr(detailEvents[0]!, 'gen_ai.input.messages')) + expect(parsed[0].parts[0].type).toBe('tool_call_response') + expect(parsed[0].parts[0].id).toBe('tu-1') + }) + + it('serializes text block content in stable convention events', () => { + const tracer = new Tracer() + + tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hello world')] }) + + const userEvents = mockSpan.getEvents('gen_ai.user.message') + const parsed = JSON.parse(eventAttr(userEvents[0]!, 'content')) + expect(parsed[0].text).toBe('Hello world') + }) + }) + + describe('system prompt on chat spans', () => { + it('emits gen_ai.system.message event with stable conventions', () => { + const tracer = new Tracer() + + tracer.startModelInvokeSpan({ + messages: [textMessage('user', 'Hello')], + modelId: 'test-model', + systemPrompt: 'You are a helpful assistant', + }) + + const systemEvents = mockSpan.getEvents('gen_ai.system.message') + expect(systemEvents).toHaveLength(1) + expect(JSON.parse(eventAttr(systemEvents[0]!, 'content'))).toStrictEqual([ + { text: 'You are a helpful assistant' }, + ]) + }) + + it('emits gen_ai.system_instructions with latest conventions', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + + tracer.startModelInvokeSpan({ + messages: [textMessage('user', 'Hello')], + modelId: 'test-model', + systemPrompt: 'You are a calculator assistant', + }) + + const detailEvents = mockSpan.getEvents('gen_ai.client.inference.operation.details') + const systemEvent = detailEvents.find((e) => eventAttr(e, 'gen_ai.system_instructions')) + expect(systemEvent).toBeDefined() + expect(JSON.parse(eventAttr(systemEvent!, 'gen_ai.system_instructions'))).toStrictEqual([ + { type: 'text', content: 'You are a calculator assistant' }, + ]) + }) + + it('does not emit system prompt event when systemPrompt is undefined', () => { + const tracer = new Tracer() + + tracer.startModelInvokeSpan({ + messages: [textMessage('user', 'Hello')], + modelId: 'test-model', + }) + + const systemEvents = mockSpan.getEvents('gen_ai.system.message') + expect(systemEvents).toHaveLength(0) + }) + + it('handles SystemContentBlock array with cache points in latest conventions', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental') + const tracer = new Tracer() + + tracer.startModelInvokeSpan({ + messages: [textMessage('user', 'Hello')], + modelId: 'test-model', + systemPrompt: [new TextBlock('You are helpful'), new CachePointBlock({ cacheType: 'default' })], + }) + + const detailEvents = mockSpan.getEvents('gen_ai.client.inference.operation.details') + const systemEvent = detailEvents.find((e) => eventAttr(e, 'gen_ai.system_instructions')) + expect(systemEvent).toBeDefined() + expect(JSON.parse(eventAttr(systemEvent!, 'gen_ai.system_instructions'))).toStrictEqual([ + { type: 'text', content: 'You are helpful' }, + { type: 'cache_point', cacheType: 'default' }, + ]) + }) + + it('serializes SystemContentBlock array in stable conventions', () => { + const tracer = new Tracer() + + tracer.startModelInvokeSpan({ + messages: [textMessage('user', 'Hello')], + modelId: 'test-model', + systemPrompt: [new TextBlock('You are helpful'), new CachePointBlock({ cacheType: 'default' })], + }) + + const systemEvents = mockSpan.getEvents('gen_ai.system.message') + expect(systemEvents).toHaveLength(1) + const parsed = JSON.parse(eventAttr(systemEvents[0]!, 'content')) + expect(parsed).toHaveLength(2) + expect(parsed[0]).toStrictEqual({ text: 'You are helpful' }) + expect(parsed[1]).toStrictEqual({ cachePoint: { cacheType: 'default' } }) + }) + }) + + describe('timeToFirstByteMs', () => { + it('does not set TTFB as span attribute (TTFB is a histogram metric, not a span attribute)', () => { + const tracer = new Tracer() + const span = tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hi')] }) + + tracer.endModelInvokeSpan(span, { + usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, + metrics: { latencyMs: 500, timeToFirstByteMs: 150 }, + }) + + expect(mockSpan.getAttributeValue('gen_ai.server.time_to_first_token')).toBeUndefined() + expect(mockSpan.getAttributeValue('gen_ai.server.request.duration')).toBe(500) + }) + }) + + describe('Langfuse detection', () => { + it('sets langfuse.observation.type on agent span when OTEL_EXPORTER_OTLP_ENDPOINT contains langfuse', () => { + vi.stubEnv('OTEL_EXPORTER_OTLP_ENDPOINT', 'https://us.cloud.langfuse.com') + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + tracer.endAgentSpan(span, { + accumulatedUsage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, + }) + + expect(mockSpan.getAttributeValue('langfuse.observation.type')).toBe('span') + }) + + it('sets langfuse.observation.type when OTEL_EXPORTER_OTLP_TRACES_ENDPOINT contains langfuse', () => { + vi.stubEnv('OTEL_EXPORTER_OTLP_TRACES_ENDPOINT', 'https://us.cloud.langfuse.com/api/public/otel/v1/traces') + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + tracer.endAgentSpan(span, { + accumulatedUsage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, + }) + + expect(mockSpan.getAttributeValue('langfuse.observation.type')).toBe('span') + }) + + it('sets langfuse.observation.type when LANGFUSE_BASE_URL is set', () => { + vi.stubEnv('LANGFUSE_BASE_URL', 'https://self-hosted.example.com') + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + tracer.endAgentSpan(span, { + accumulatedUsage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, + }) + + expect(mockSpan.getAttributeValue('langfuse.observation.type')).toBe('span') + }) + + it('does not set langfuse.observation.type when no langfuse env vars are set', () => { + vi.stubEnv('OTEL_EXPORTER_OTLP_ENDPOINT', '') + vi.stubEnv('OTEL_EXPORTER_OTLP_TRACES_ENDPOINT', '') + vi.stubEnv('LANGFUSE_BASE_URL', '') + const tracer = new Tracer() + const span = tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + tracer.endAgentSpan(span, { + accumulatedUsage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, + }) + + expect(mockSpan.getAttributeValue('langfuse.observation.type')).toBeUndefined() + }) + }) + + describe('error resilience', () => { + it.each([ + { + method: 'startAgentSpan', + call: (tracer: Tracer) => tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }), + }, + { + method: 'startModelInvokeSpan', + call: (tracer: Tracer) => tracer.startModelInvokeSpan({ messages: [textMessage('user', 'Hi')] }), + }, + { + method: 'startToolCallSpan', + call: (tracer: Tracer) => tracer.startToolCallSpan({ tool: { name: 'x', toolUseId: 'y', input: {} } }), + }, + { + method: 'startAgentLoopSpan', + call: (tracer: Tracer) => tracer.startAgentLoopSpan({ cycleId: 'c', messages: [textMessage('user', 'Hi')] }), + }, + ])('returns null when $method throws internally', ({ call }) => { + mockStartSpan.mockImplementation(() => { + throw new Error('otel failure') + }) + const tracer = new Tracer() + + expect(call(tracer)).toBeNull() + }) + + it('does not throw when ending null spans with errors', () => { + const tracer = new Tracer() + + expect(() => { + tracer.endAgentSpan(null, { error: new Error('test') }) + tracer.endModelInvokeSpan(null, { error: new Error('test') }) + tracer.endToolCallSpan(null, { error: new Error('test') }) + tracer.endAgentLoopSpan(null, { error: new Error('test') }) + }).not.toThrow() + }) + }) + + describe('semantic convention opt-in parsing', () => { + it('parses multiple comma-separated opt-in values', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', 'gen_ai_latest_experimental,gen_ai_tool_definitions') + const tracer = new Tracer() + const toolsConfig = { calc: { name: 'calc', description: 'Calculator' } } + + tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent', toolsConfig }) + + const [, options] = getStartSpanCall() + expect(options.attributes['gen_ai.provider.name']).toBeDefined() + expect(options.attributes['gen_ai.tool.definitions']).toBe(JSON.stringify(toolsConfig)) + }) + + it('handles whitespace in opt-in values', () => { + vi.stubEnv('OTEL_SEMCONV_STABILITY_OPT_IN', ' gen_ai_latest_experimental , gen_ai_tool_definitions ') + const tracer = new Tracer() + + tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + const [, options] = getStartSpanCall() + expect(options.attributes['gen_ai.provider.name']).toBeDefined() + }) + + it('defaults to stable conventions when env var is empty', () => { + const tracer = new Tracer() + + tracer.startAgentSpan({ messages: [textMessage('user', 'Hi')], agentName: 'agent' }) + + const [, options] = getStartSpanCall() + expect(options.attributes['gen_ai.system']).toBeDefined() + expect(options.attributes['gen_ai.provider.name']).toBeUndefined() + }) + }) +}) diff --git a/strands-ts/src/telemetry/config.ts b/strands-ts/src/telemetry/config.ts new file mode 100644 index 0000000000..ad8538702f --- /dev/null +++ b/strands-ts/src/telemetry/config.ts @@ -0,0 +1,287 @@ +/** + * OpenTelemetry configuration and setup utilities for Strands agents. + * + * Provides {@link setupTracer} for distributed tracing and {@link setupMeter} + * for OTEL metrics export. Both use the global OTel API so any provider + * registered here (or by the user) is automatically picked up by the Agent. + * + * This module is only loaded when the user explicitly imports and calls + * {@link setupTracer} or {@link setupMeter}. The core agent loop + * (tracer.ts, meter.ts) does not depend on this module. + * + * Uses NodeTracerProvider when available for async context propagation + * across MCP server boundaries. Falls back to BasicTracerProvider in + * environments without async_hooks support. + */ + +import { context as otelContext, metrics as otelMetrics, propagation, trace } from '@opentelemetry/api' +import type { + ContextManager, + Meter as OtelMeter, + TextMapPropagator, + TracerProvider, + Tracer as OtelTracer, +} from '@opentelemetry/api' +import { resourceFromAttributes, envDetector, type Resource } from '@opentelemetry/resources' +import { + BasicTracerProvider, + ConsoleSpanExporter, + SimpleSpanProcessor, + BatchSpanProcessor, + type SpanProcessor, +} from '@opentelemetry/sdk-trace-base' +import { + MeterProvider, + PeriodicExportingMetricReader, + ConsoleMetricExporter, + type MetricReader, +} from '@opentelemetry/sdk-metrics' +import { OTLPTraceExporter } from '@opentelemetry/exporter-trace-otlp-http' +import { OTLPMetricExporter } from '@opentelemetry/exporter-metrics-otlp-http' +import { logger } from '../logging/index.js' +import { getServiceName } from './utils.js' + +let DefaultTracerProvider: typeof BasicTracerProvider = BasicTracerProvider +let DefaultContextManager: (new () => ContextManager) | undefined +let DefaultPropagator: TextMapPropagator | undefined +if (typeof globalThis.process?.getBuiltinModule === 'function') { + try { + const nodeModule = globalThis.process.getBuiltinModule('node:module') as typeof import('module') | undefined + if (nodeModule) { + const req = nodeModule.createRequire(import.meta.url) + DefaultTracerProvider = req('@opentelemetry/sdk-trace-node').NodeTracerProvider + DefaultContextManager = req('@opentelemetry/context-async-hooks').AsyncLocalStorageContextManager + const { W3CTraceContextPropagator, W3CBaggagePropagator, CompositePropagator } = req('@opentelemetry/core') + DefaultPropagator = new CompositePropagator({ + propagators: [new W3CTraceContextPropagator(), new W3CBaggagePropagator()], + }) + } + } catch { + logger.info('sdk-trace-node not available | using BasicTracerProvider without async context propagation') + } +} + +const DEFAULT_SERVICE_NAMESPACE = 'strands' +const DEFAULT_DEPLOYMENT_ENVIRONMENT = 'development' + +/** + * Get an OpenTelemetry Tracer instance. + * + * Wraps the OTel trace API to provide a consistent tracer scoped to the + * configured service name. + * + * @returns An OTel Tracer instance from the global tracer provider + * + * @example + * ```typescript + * import { setupTracer, getTracer } from '@strands-agents/sdk/telemetry' + * + * // Set up telemetry first (or register your own NodeTracerProvider) + * setupTracer({ exporters: { otlp: true } }) + * + * // Get a tracer and create custom spans + * const tracer = getTracer() + * const span = tracer.startSpan('my-custom-operation') + * span.setAttribute('custom.key', 'value') + * + * // ........ + * + * span.end() + * ``` + */ +export function getTracer(): OtelTracer { + return trace.getTracer(getServiceName()) +} + +/** + * Get an OpenTelemetry Meter instance. + * + * Wraps the OTel metrics API to provide a consistent meter scoped to the + * configured service name. Returns a no-op meter until a MeterProvider is + * registered (either via {@link setupMeter} or by the user directly). + * + * @returns An OTel Meter instance from the global meter provider + * + * @example + * ```typescript + * import { setupMeter, getMeter } from '@strands-agents/sdk/telemetry' + * + * setupMeter({ exporters: { otlp: true } }) + * + * const meter = getMeter() + * const counter = meter.createCounter('my.custom.counter') + * counter.add(1) + * ``` + */ +export function getMeter(): OtelMeter { + return otelMetrics.getMeter(getServiceName()) +} + +/** + * Configuration options for setting up the tracer. + */ +export interface TracerConfig { + /** + * Custom TracerProvider instance. If not provided, NodeTracerProvider is + * used when available, otherwise BasicTracerProvider. + */ + provider?: TracerProvider + + /** + * Exporter configuration. + */ + exporters?: { + /** + * Enable OTLP exporter. Uses OTEL_EXPORTER_OTLP_ENDPOINT and + * OTEL_EXPORTER_OTLP_HEADERS env vars automatically. + */ + otlp?: boolean + /** + * Enable console exporter for debugging. + */ + console?: boolean + } +} + +let _provider: BasicTracerProvider | null = null +let _customProvider: TracerProvider | null = null + +/** + * Set up the tracer provider with the given configuration. + * + * When called without a custom provider, returns a BasicTracerProvider and + * registers the async context manager + W3C propagators for trace propagation. + * When a custom provider is passed, the caller is responsible for their own + * context manager / propagator setup (e.g. via provider.register()). + * + * @param config - Tracer configuration options + * @returns The configured tracer provider + * + * @example + * ```typescript + * import { telemetry } from '\@strands-agents/sdk' + * + * telemetry.setupTracer({ exporters: { otlp: true } }) + * ``` + */ +export function setupTracer(config?: Omit): BasicTracerProvider +export function setupTracer(config: TracerConfig): TracerProvider +export function setupTracer(config: TracerConfig = {}): TracerProvider { + if (_provider || _customProvider) { + logger.warn('tracer provider already initialized, returning existing provider') + return _customProvider ?? _provider! + } + + if (config.provider) { + _customProvider = config.provider + trace.setGlobalTracerProvider(_customProvider) + return _customProvider + } + + const spanProcessors: SpanProcessor[] = [] + if (config.exporters?.otlp) spanProcessors.push(new BatchSpanProcessor(new OTLPTraceExporter())) + if (config.exporters?.console) spanProcessors.push(new SimpleSpanProcessor(new ConsoleSpanExporter())) + _provider = new DefaultTracerProvider({ resource: getOtelResource(), spanProcessors }) + + trace.setGlobalTracerProvider(_provider) + if (DefaultContextManager) otelContext.setGlobalContextManager(new DefaultContextManager()) + if (DefaultPropagator) propagation.setGlobalPropagator(DefaultPropagator) + + if (typeof globalThis.process?.once === 'function') { + globalThis.process.once('beforeExit', () => { + _provider?.forceFlush()?.catch((err: unknown) => { + logger.warn(`error=<${err}> | failed to flush tracer provider on exit`) + }) + }) + } + + return _provider +} + +/** + * Configuration options for setting up the OTEL meter provider. + */ +export interface MeterConfig { + /** + * Custom MeterProvider instance. When provided, it is registered as the + * global meter provider and the SDK will not create one internally. + */ + provider?: MeterProvider + + /** + * Exporter configuration. + */ + exporters?: { + /** + * Enable OTLP exporter. Uses OTEL_EXPORTER_OTLP_ENDPOINT and + * OTEL_EXPORTER_OTLP_HEADERS env vars automatically. + */ + otlp?: boolean + /** + * Enable console exporter for debugging. + */ + console?: boolean + } +} + +let _meterProvider: MeterProvider | null = null + +/** + * Set up the OTEL meter provider with the given configuration. + * + * @param config - Meter configuration options + * @returns The configured meter provider + * + * @example + * ```typescript + * import { telemetry } from '\@strands-agents/sdk' + * + * telemetry.setupMeter({ exporters: { otlp: true } }) + * ``` + */ +export function setupMeter(config: MeterConfig = {}): MeterProvider { + if (_meterProvider) { + logger.warn('meter provider already initialized, returning existing provider') + return _meterProvider + } + + if (config.provider) { + _meterProvider = config.provider + } else { + const readers: MetricReader[] = [] + if (config.exporters?.otlp) readers.push(new PeriodicExportingMetricReader({ exporter: new OTLPMetricExporter() })) + if (config.exporters?.console) + readers.push(new PeriodicExportingMetricReader({ exporter: new ConsoleMetricExporter() })) + _meterProvider = new MeterProvider({ resource: getOtelResource(), readers }) + } + + otelMetrics.setGlobalMeterProvider(_meterProvider) + + if (typeof globalThis.process?.once === 'function') { + globalThis.process.once('beforeExit', () => { + if (_meterProvider) { + _meterProvider.forceFlush().catch((err: unknown) => { + logger.warn(`error=<${err}> | failed to flush meter provider on exit`) + }) + } + }) + } + + return _meterProvider +} + +function getOtelResource(): Resource { + const serviceName = getServiceName() + const serviceNamespace = globalThis.process?.env?.OTEL_SERVICE_NAMESPACE || DEFAULT_SERVICE_NAMESPACE + const deploymentEnvironment = globalThis.process?.env?.OTEL_DEPLOYMENT_ENVIRONMENT || DEFAULT_DEPLOYMENT_ENVIRONMENT + + const envAttributes = envDetector.detect().attributes ?? {} + return resourceFromAttributes({ + 'service.name': serviceName, + 'service.namespace': serviceNamespace, + 'deployment.environment': deploymentEnvironment, + 'telemetry.sdk.name': 'opentelemetry', + 'telemetry.sdk.language': 'typescript', + ...envAttributes, + }) +} diff --git a/strands-ts/src/telemetry/index.ts b/strands-ts/src/telemetry/index.ts new file mode 100644 index 0000000000..5e2975b511 --- /dev/null +++ b/strands-ts/src/telemetry/index.ts @@ -0,0 +1,37 @@ +/** + * OpenTelemetry telemetry support for Strands Agents SDK. + * + * This module provides `setupTracer()` to configure a NodeTracerProvider + * with OTLP or console exporters, and `setupMeter()` to configure a + * MeterProvider for OTEL metrics export. The Agent class handles tracing + * and metrics internally once telemetry is configured. + * + * @example Basic setup with OTLP exporter + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { setupTracer, setupMeter } from '@strands-agents/sdk/telemetry' + * + * // Configure telemetry with OTLP exporter + * setupTracer({ exporters: { otlp: true } }) + * setupMeter({ exporters: { otlp: true } }) + * + * // Agent automatically traces invocations and emits metrics + * const agent = new Agent() + * ``` + * + * @example Using your own OpenTelemetry provider + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { NodeTracerProvider } from '@opentelemetry/sdk-trace-node' + * + * // Set up your own provider + * const provider = new NodeTracerProvider() + * provider.register() + * + * // Agent automatically uses your provider via the global OTel API + * const agent = new Agent() + * ``` + */ + +export { setupTracer, getTracer, setupMeter, getMeter } from './config.js' +export type { TracerConfig, MeterConfig } from './config.js' diff --git a/strands-ts/src/telemetry/json.ts b/strands-ts/src/telemetry/json.ts new file mode 100644 index 0000000000..289ed63a69 --- /dev/null +++ b/strands-ts/src/telemetry/json.ts @@ -0,0 +1,24 @@ +/** + * Custom replacer for JSON.stringify that handles non-serializable types. + * Converts Date to ISO string and replaces binary data, functions, symbols, + * and BigInt with ''. + * + * @param _key - The property key (unused) + * @param value - The value to process + * @returns A JSON-safe value + */ +export function jsonReplacer(_key: string, value: unknown): unknown { + switch (true) { + case value instanceof Date: + return value.toISOString() + case typeof value === 'bigint': + case typeof value === 'function': + case typeof value === 'symbol': + case value instanceof ArrayBuffer: + case value instanceof Uint8Array: + case ArrayBuffer.isView(value): + return '' + default: + return value + } +} diff --git a/strands-ts/src/telemetry/meter.ts b/strands-ts/src/telemetry/meter.ts new file mode 100644 index 0000000000..a19fb8c254 --- /dev/null +++ b/strands-ts/src/telemetry/meter.ts @@ -0,0 +1,563 @@ +/** + * Agent loop metrics tracking. + * + * The {@link Meter} accumulates local metrics during agent invocation and + * provides them as a read-only {@link AgentMetrics} snapshot via the + * {@link Meter.metrics} getter for inclusion in {@link AgentResult}. + * + * When an OTEL MeterProvider is registered (via {@link setupMeter} or + * directly), the Meter also emits counters and histograms through the + * global OTEL metrics API, enabling export to OTLP backends. + */ + +import type { Counter, Histogram, Meter as OtelMeter } from '@opentelemetry/api' +import { metrics as otelMetrics } from '@opentelemetry/api' +import type { Usage, Metrics, ModelMetadataEventData } from '../models/streaming.js' +import { accumulateUsage, createEmptyUsage } from '../models/streaming.js' +import type { ToolUse } from '../tools/types.js' +import type { JSONSerializable } from '../types/json.js' +import { getServiceName } from './utils.js' + +/** + * Per-tool execution metrics. + */ +export interface ToolMetricsData { + /** + * Total number of calls to this tool. + */ + callCount: number + + /** + * Number of successful calls. + */ + successCount: number + + /** + * Number of failed calls. + */ + errorCount: number + + /** + * Total execution time in milliseconds. + */ + totalTime: number +} + +/** + * Per-cycle usage tracking. + */ +export interface AgentLoopMetricsData { + /** + * Unique identifier for this cycle. + */ + cycleId: string + + /** + * Duration of this cycle in milliseconds. + */ + duration: number + + /** + * Token usage for this cycle. + */ + usage: Usage +} + +/** + * Per-invocation metrics tracking. + */ +export interface InvocationMetricsData { + /** + * Cycle metrics for this invocation. + */ + cycles: AgentLoopMetricsData[] + + /** + * Accumulated token usage for this invocation. + */ + usage: Usage +} + +/** + * JSON-serializable representation of AgentMetrics. + */ +export interface AgentMetricsData { + /** + * Total number of agent loop cycles executed across all invocations. + */ + cycleCount: number + + /** + * Accumulated token usage across all model invocations. + */ + accumulatedUsage: Usage + + /** + * Accumulated performance metrics across all model invocations. + */ + accumulatedMetrics: Metrics + + /** + * Per-invocation metrics for recent invocations. + * Only the most recent 50 entries are retained. + */ + agentInvocations: InvocationMetricsData[] + + /** + * Per-tool execution metrics keyed by tool name. + */ + toolMetrics: Record + + /** + * The most recent input token count from the last model invocation. + * Represents the current context window utilization. + */ + latestContextSize?: number + + /** + * Projected context size for the next model call (inputTokens + outputTokens from the last call). + * Represents the baseline token count the next invocation will start with. + */ + projectedContextSize?: number + + /** + * Total duration of all cycles across all invocations in milliseconds. + */ + totalDuration?: number +} + +/** + * Options for recording tool usage. + */ +interface ToolUsageOptions { + /** + * The tool that was used. + */ + tool: ToolUse + + /** + * Execution duration in milliseconds. + */ + duration: number + + /** + * Whether the tool call succeeded. + */ + success: boolean +} + +/** + * Read-only snapshot of aggregated agent metrics. + * + * Returned by {@link Meter.metrics} and stored on {@link AgentResult}. + * Provides access to cycle counts, tool usage, token consumption, + * and per-invocation breakdowns. Supports serialization via {@link toJSON}. + * + * @example + * ```typescript + * const result = await agent.invoke('Hello') + * console.log(result.metrics?.cycleCount) + * console.log(result.metrics?.totalDuration) + * console.log(result.metrics?.accumulatedData) + * console.log(result.metrics?.toolMetrics) + * console.log(JSON.stringify(result.metrics)) + * ``` + */ +export class AgentMetrics implements JSONSerializable { + /** + * Total number of agent loop cycles executed across all invocations. + */ + readonly cycleCount: number + + /** + * Accumulated token usage across all model invocations. + */ + readonly accumulatedUsage: Usage + + /** + * Accumulated performance metrics across all model invocations. + */ + readonly accumulatedMetrics: Metrics + + /** + * Per-invocation metrics for recent invocations. + * Only the most recent 50 entries are retained to prevent unbounded memory growth. + * For full history, collect {@link latestAgentInvocation} from each {@link AgentResult}. + */ + readonly agentInvocations: InvocationMetricsData[] + + /** + * Per-tool execution metrics keyed by tool name. + */ + readonly toolMetrics: Record + + /** + * The most recent input token count from the last model invocation. + * Represents the current context window utilization. + * Returns `undefined` when no invocations have occurred. + */ + readonly latestContextSize: number | undefined + + /** + * Projected context size for the next model call (inputTokens + outputTokens from the last call). + * Represents the baseline token count the next invocation will start with. + * Returns `undefined` when no invocations have occurred. + */ + readonly projectedContextSize: number | undefined + + /** + * Total duration of all cycles across all invocations in milliseconds. + */ + readonly totalDuration: number + + constructor(data?: Partial) { + this.cycleCount = data?.cycleCount ?? 0 + this.accumulatedUsage = data?.accumulatedUsage ?? createEmptyUsage() + this.accumulatedMetrics = data?.accumulatedMetrics ?? { latencyMs: 0 } + this.agentInvocations = data?.agentInvocations ?? [] + this.toolMetrics = data?.toolMetrics ?? {} + this.latestContextSize = data?.latestContextSize + this.projectedContextSize = data?.projectedContextSize + this.totalDuration = data?.totalDuration ?? 0 + } + + /** + * The most recent agent invocation, or undefined if none exist. + */ + get latestAgentInvocation(): InvocationMetricsData | undefined { + return this.agentInvocations.length > 0 ? this.agentInvocations[this.agentInvocations.length - 1] : undefined + } + + /** + * Accumulated usage and performance metrics across all model invocations. + */ + get accumulatedData(): { usage: Usage; metrics: Metrics } { + return { usage: this.accumulatedUsage, metrics: this.accumulatedMetrics } + } + + /** + * Average cycle duration in milliseconds, or 0 if no cycles exist. + */ + get averageCycleTime(): number { + return this.cycleCount > 0 ? this.totalDuration / this.cycleCount : 0 + } + + /** + * Per-tool execution statistics with computed averages and rates. + */ + get toolUsage(): Record { + const usage: Record = {} + for (const [toolName, toolEntry] of Object.entries(this.toolMetrics)) { + usage[toolName] = { + ...toolEntry, + averageTime: toolEntry.callCount > 0 ? toolEntry.totalTime / toolEntry.callCount : 0, + successRate: toolEntry.callCount > 0 ? toolEntry.successCount / toolEntry.callCount : 0, + } + } + return usage + } + + /** + * Returns a JSON-serializable representation of all collected metrics. + * Called automatically by JSON.stringify(). + * + * @returns A plain object suitable for round-trip serialization + */ + toJSON(): AgentMetricsData { + return { + cycleCount: this.cycleCount, + accumulatedUsage: this.accumulatedUsage, + accumulatedMetrics: this.accumulatedMetrics, + agentInvocations: this.agentInvocations, + toolMetrics: this.toolMetrics, + totalDuration: this.totalDuration, + ...(this.latestContextSize !== undefined && { latestContextSize: this.latestContextSize }), + ...(this.projectedContextSize !== undefined && { projectedContextSize: this.projectedContextSize }), + } + } +} + +/** + * Maximum number of invocation history entries retained by the Meter. + * Prevents unbounded memory growth on long-lived Agent instances. + * Users who need full history can collect per-invocation metrics + * from successive AgentResult objects. + */ +const MAX_INVOCATION_HISTORY = 50 + +/** + * Accumulates local metrics during agent invocation. + * + * Tracks cycle counts, token usage, tool execution stats, and model latency. + * Use the {@link metrics} getter to obtain a read-only {@link AgentMetrics} + * snapshot for inclusion in {@link AgentResult}. + * + * When an OTEL MeterProvider is registered, the same data is also emitted + * as OTEL counters and histograms via the global metrics API. If no + * provider is registered the OTEL meter is a no-op and adds no overhead. + */ +export class Meter { + /** + * Number of agent loop cycles executed. + */ + private _cycleCount: number = 0 + + /** + * Accumulated token usage across all model invocations. + */ + private readonly _accumulatedUsage: Usage = createEmptyUsage() + + /** + * Accumulated performance metrics across all model invocations. + */ + private readonly _accumulatedMetrics: Metrics = { latencyMs: 0 } + + /** + * Per-invocation metrics. + */ + private readonly _agentInvocations: InvocationMetricsData[] = [] + + /** + * Per-tool execution metrics keyed by tool name. + */ + private readonly _toolMetrics: Record = {} + + /** + * The most recent input token count from the last model invocation. + */ + private _latestContextSize: number | undefined + + /** + * Projected context size for the next model call (inputTokens + outputTokens). + */ + private _projectedContextSize: number | undefined + + /** + * Running total of all cycle durations in milliseconds. + */ + private _totalDuration: number = 0 + + // OTEL instruments (no-op when no MeterProvider is registered) + private readonly _otelMeter: OtelMeter + private readonly _otelCycleCounter: Counter + private readonly _otelInvocationCounter: Counter + private readonly _otelCycleDuration: Histogram + private readonly _otelToolCallCounter: Counter + private readonly _otelToolErrorCounter: Counter + private readonly _otelToolDuration: Histogram + private readonly _otelInputTokens: Counter + private readonly _otelOutputTokens: Counter + private readonly _otelModelLatency: Histogram + private readonly _otelTimeToFirstToken: Histogram + + constructor() { + this._otelMeter = otelMetrics.getMeter(getServiceName()) + + this._otelCycleCounter = this._otelMeter.createCounter('gen_ai.agent.cycle.count', { + description: 'Number of agent loop cycles executed', + }) + this._otelInvocationCounter = this._otelMeter.createCounter('gen_ai.agent.invocation.count', { + description: 'Number of agent invocations', + }) + this._otelCycleDuration = this._otelMeter.createHistogram('gen_ai.agent.cycle.duration', { + description: 'Duration of agent loop cycles in milliseconds', + unit: 'ms', + }) + this._otelToolCallCounter = this._otelMeter.createCounter('gen_ai.agent.tool.call.count', { + description: 'Number of tool calls', + }) + this._otelToolErrorCounter = this._otelMeter.createCounter('gen_ai.agent.tool.error.count', { + description: 'Number of failed tool calls', + }) + this._otelToolDuration = this._otelMeter.createHistogram('gen_ai.agent.tool.duration', { + description: 'Duration of tool calls in milliseconds', + unit: 'ms', + }) + this._otelInputTokens = this._otelMeter.createCounter('gen_ai.agent.tokens.input', { + description: 'Input tokens consumed', + }) + this._otelOutputTokens = this._otelMeter.createCounter('gen_ai.agent.tokens.output', { + description: 'Output tokens consumed', + }) + this._otelModelLatency = this._otelMeter.createHistogram('gen_ai.agent.model.latency', { + description: 'Model invocation latency in milliseconds', + unit: 'ms', + }) + // OTel GenAI semconv requires seconds for this metric, unlike the SDK-internal histograms which use ms + this._otelTimeToFirstToken = this._otelMeter.createHistogram('gen_ai.server.time_to_first_token', { + description: 'Time to generate first token for successful responses', + unit: 's', + }) + } + + /** + * Begin tracking a new agent invocation. + * Creates a new InvocationMetricsData entry for per-invocation metrics. + * Evicts the oldest entry when the history exceeds MAX_INVOCATION_HISTORY. + */ + startNewInvocation(): void { + if (this._agentInvocations.length >= MAX_INVOCATION_HISTORY) { + this._agentInvocations.shift() + } + + this._agentInvocations.push({ + cycles: [], + usage: createEmptyUsage(), + }) + this._otelInvocationCounter.add(1) + } + + /** + * Start a new agent loop cycle. + * + * @returns The cycle id and start time + */ + startCycle(): { cycleId: string; startTime: number } { + this._cycleCount++ + this._otelCycleCounter.add(1) + + const cycleId = `cycle-${this._cycleCount}` + const startTime = Date.now() + + const latestInvocation = this._latestAgentInvocation + if (latestInvocation) { + latestInvocation.cycles.push({ + cycleId: cycleId, + duration: 0, + usage: createEmptyUsage(), + }) + } + + return { cycleId, startTime } + } + + /** + * End the current agent loop cycle and record its duration. + * + * @param startTime - The timestamp when the cycle started (milliseconds since epoch) + */ + endCycle(startTime: number): void { + const duration = Date.now() - startTime + this._otelCycleDuration.record(duration) + + this._totalDuration += duration + + const latestInvocation = this._latestAgentInvocation + if (latestInvocation) { + const cycles = latestInvocation.cycles + if (cycles.length > 0) { + cycles[cycles.length - 1]!.duration = duration + } + } + } + + /** + * Record metrics for a completed tool invocation. + * + * @param options - Tool usage recording options + */ + endToolCall(options: ToolUsageOptions): void { + const { tool, duration, success } = options + const toolName = tool.name + + if (!this._toolMetrics[toolName]) { + this._toolMetrics[toolName] = { callCount: 0, successCount: 0, errorCount: 0, totalTime: 0 } + } + + const toolEntry = this._toolMetrics[toolName]! + toolEntry.callCount++ + toolEntry.totalTime += duration + + const attrs = { 'gen_ai.tool.name': toolName } + this._otelToolCallCounter.add(1, attrs) + this._otelToolDuration.record(duration, attrs) + + if (success) { + toolEntry.successCount++ + } else { + toolEntry.errorCount++ + this._otelToolErrorCounter.add(1, attrs) + } + } + + /** + * Update loop-level metrics from a model response. + * + * Call this after each model invocation within a cycle to + * accumulate usage and latency. + * + * @param metadata - The metadata event from a model invocation, or undefined if unavailable + */ + updateCycle(metadata?: ModelMetadataEventData): void { + if (metadata) { + this._updateFromMetadata(metadata) + } + } + + /** + * Read-only snapshot of the accumulated metrics. + * Returns an AgentMetrics instance suitable for inclusion in AgentResult. + */ + get metrics(): AgentMetrics { + return new AgentMetrics({ + cycleCount: this._cycleCount, + accumulatedUsage: this._accumulatedUsage, + accumulatedMetrics: this._accumulatedMetrics, + agentInvocations: this._agentInvocations, + toolMetrics: this._toolMetrics, + totalDuration: this._totalDuration, + ...(this._latestContextSize !== undefined && { latestContextSize: this._latestContextSize }), + ...(this._projectedContextSize !== undefined && { projectedContextSize: this._projectedContextSize }), + }) + } + + /** + * The most recent agent invocation, or undefined if none exist. + */ + private get _latestAgentInvocation(): InvocationMetricsData | undefined { + return this._agentInvocations.length > 0 ? this._agentInvocations[this._agentInvocations.length - 1] : undefined + } + + /** + * Update accumulated usage and metrics from a model metadata event. + * + * @param metadata - The metadata event from a model invocation + */ + private _updateFromMetadata(metadata: ModelMetadataEventData): void { + if (metadata.usage) { + this._updateUsage(metadata.usage) + } + if (metadata.metrics) { + this._accumulatedMetrics.latencyMs += metadata.metrics.latencyMs + this._otelModelLatency.record(metadata.metrics.latencyMs) + + if (metadata.metrics.timeToFirstByteMs !== undefined && metadata.metrics.timeToFirstByteMs > 0) { + this._otelTimeToFirstToken.record(metadata.metrics.timeToFirstByteMs / 1000) + } + } + } + + /** + * Update the accumulated token usage with new usage data. + * + * @param usage - The usage data to accumulate + */ + private _updateUsage(usage: Usage): void { + accumulateUsage(this._accumulatedUsage, usage) + this._latestContextSize = usage.inputTokens + this._projectedContextSize = usage.inputTokens + usage.outputTokens + + this._otelInputTokens.add(usage.inputTokens) + this._otelOutputTokens.add(usage.outputTokens) + + const latestInvocation = this._latestAgentInvocation + if (latestInvocation) { + accumulateUsage(latestInvocation.usage, usage) + + const cycles = latestInvocation.cycles + if (cycles.length > 0) { + accumulateUsage(cycles[cycles.length - 1]!.usage, usage) + } + } + } +} diff --git a/strands-ts/src/telemetry/tracer.ts b/strands-ts/src/telemetry/tracer.ts new file mode 100644 index 0000000000..b1323a4bd8 --- /dev/null +++ b/strands-ts/src/telemetry/tracer.ts @@ -0,0 +1,1026 @@ +/** + * OpenTelemetry tracing and local execution trace management. + * + * This module provides tracing capabilities using OpenTelemetry, + * enabling trace data to be sent to OTLP endpoints. + * + * Uses a fully stateful approach via OpenTelemetry's context propagation. + * Parent-child relationships are established automatically through + * context.active(). Use context.with() to set a span as active before + * creating child spans. + * + * Lightweight in-memory LocalTrace trees are always collected regardless + * of OTel configuration and surfaced via AgentResult.traces. + * + * @example + * ```typescript + * const tracer = new Tracer() + * const parentSpan = tracer.startAgentSpan({ ... }) + * + * context.with(trace.setSpan(context.active(), parentSpan), async () => { + * const modelSpan = tracer.startModelInvokeSpan({ messages }) + * tracer.endModelInvokeSpan(modelSpan) + * }) + * + * tracer.endAgentSpan(parentSpan) + * ``` + */ + +import { context, SpanStatusCode, SpanKind, trace } from '@opentelemetry/api' +import type { Span, Tracer as OtelTracer, SpanOptions, AttributeValue } from '@opentelemetry/api' +import { logger } from '../logging/index.js' +import type { + EndAgentSpanOptions, + EndModelSpanOptions, + EndToolCallSpanOptions, + EndAgentLoopSpanOptions, + StartAgentSpanOptions, + StartModelInvokeSpanOptions, + StartToolCallSpanOptions, + StartAgentLoopSpanOptions, + StartMultiAgentSpanOptions, + EndMultiAgentSpanOptions, + StartNodeSpanOptions, + EndNodeSpanOptions, + Usage, + Metrics, +} from './types.js' +import type { ContentBlock, Message, SystemPrompt } from '../types/messages.js' +import type { JSONSerializable } from '../types/json.js' +import { jsonReplacer } from './json.js' +import { getServiceName } from './utils.js' + +/** + * JSON-serializable representation of LocalTrace. + */ +interface AgentTraceData { + id: string + name: string + parentId: string | null + startTime: number + endTime: number | null + duration: number + children: AgentTraceData[] + metadata: Record + message: Message | null +} + +/** + * Execution trace for performance analysis. + * Tracks timing and hierarchy of operations within the agent loop. + * Fields default to null for JSON serialization compatibility. + */ +export class AgentTrace implements JSONSerializable { + /** Unique identifier (UUID) for this trace. */ + readonly id: string + /** Human-readable display name (e.g., "Cycle 1", "Tool: calc", "stream_messages"). */ + readonly name: string + /** ID of the parent trace, if this trace is nested. Null for root traces. */ + readonly parentId: string | null + /** Start time in milliseconds since epoch. */ + readonly startTime: number + /** End time in milliseconds since epoch. Null until trace is ended. */ + endTime: number | null = null + /** Duration in milliseconds (endTime - startTime). */ + duration: number = 0 + /** Child traces nested under this trace. */ + readonly children: AgentTrace[] = [] + /** Additional metadata for this trace (e.g., cycleId, toolUseId, toolName). */ + readonly metadata: Record = {} + /** Message associated with this trace (e.g., model output). Null if not applicable. */ + message: Message | null = null + + /** + * @param name - Display name for this trace + * @param options - Optional configuration for parent and startTime + */ + constructor(name: string, options?: { parent?: AgentTrace; startTime?: number }) { + this.id = globalThis.crypto.randomUUID() + this.name = name + this.parentId = options?.parent?.id ?? null + this.startTime = options?.startTime ?? Date.now() + + if (options?.parent) { + options.parent.children.push(this) + } + } + + /** + * @param endTime - Optional end time in milliseconds since epoch + */ + end(endTime?: number): void { + this.endTime = endTime ?? Date.now() + this.duration = this.endTime - this.startTime + } + + toJSON(): AgentTraceData { + return { + id: this.id, + name: this.name, + parentId: this.parentId, + startTime: this.startTime, + endTime: this.endTime, + duration: this.duration, + children: this.children.map((child) => child.toJSON()), + metadata: this.metadata, + message: this.message, + } + } +} + +/** + * In-memory execution trace state, collected independently of OTel. + * Always active regardless of whether setupTracer() has been called. + */ +interface AgentTraceState { + /** Completed and in-progress cycle traces. */ + traces: AgentTrace[] + /** Current cycle trace, parents model and tool traces. */ + currentCycle?: AgentTrace | undefined + /** Current model invocation trace. */ + currentModel?: AgentTrace | undefined + /** Current tool call trace. */ + currentTool?: AgentTrace | undefined +} + +/** + * Manages both OpenTelemetry spans and local execution traces for agent operations. + * + * OTel spans are exported to external observability backends (Jaeger, X-Ray, etc.) + * when configured via setupTracer(). Local traces are lightweight, in-memory timing + * trees that are always collected regardless of OTel configuration and returned + * in AgentResult.traces for programmatic access. + * + * + */ +export class Tracer { + /** + * OpenTelemetry tracer instance obtained from the global API. + */ + private readonly _tracer: OtelTracer + + /** + * Whether to use latest experimental semantic conventions. + * + * Enabled via `OTEL_SEMCONV_STABILITY_OPT_IN=gen_ai_latest_experimental`. + * Changes attribute names (e.g., `gen_ai.system` → `gen_ai.provider.name`) and + * event formats (single `gen_ai.client.inference.operation.details` event vs + * separate per-message events). Enable when your observability backend supports + * newer GenAI conventions. + * + * @see https://opentelemetry.io/docs/specs/semconv/gen-ai/ + */ + private readonly _useLatestConventions: boolean + + /** + * Whether to include full tool JSON schemas in span attributes. + * + * Enabled via `OTEL_SEMCONV_STABILITY_OPT_IN=gen_ai_tool_definitions`. + * Useful for debugging tool configuration issues. Disabled by default to + * reduce span payload size and observability costs. + * + * Can be combined with other options: + * `OTEL_SEMCONV_STABILITY_OPT_IN=gen_ai_latest_experimental,gen_ai_tool_definitions` + */ + private readonly _includeToolDefinitions: boolean + + /** + * Custom attributes to include on all spans created by this tracer. + */ + private readonly _traceAttributes: Record + + /** Root span for the current agent invocation. */ + private _agentSpan: Span | undefined + + /** Span for the current agent loop cycle, used to parent model and tool spans. */ + private _loopSpan: Span | undefined + + /** Root span for the current multi-agent orchestration, used to parent node spans. */ + private _multiAgentSpan: Span | undefined + + /** + * Whether Langfuse is configured as the OTLP endpoint. + * Detected from OTEL_EXPORTER_OTLP_ENDPOINT, OTEL_EXPORTER_OTLP_TRACES_ENDPOINT, + * or LANGFUSE_BASE_URL environment variables. + */ + private readonly _isLangfuse: boolean + + /** In-memory execution trace state, collected independently of OTEL. */ + private readonly _traceState: AgentTraceState = { traces: [] } + + /** + * Initialize the tracer with OpenTelemetry configuration. + * Reads OTEL_SEMCONV_STABILITY_OPT_IN to determine convention version. + * Gets tracer from the global API to ensure ground truth - works correctly + * whether the user or Strands initialized the tracer provider. + * + * @param traceAttributes - Optional custom attributes to include on all spans + */ + constructor(traceAttributes?: Record) { + this._traceAttributes = traceAttributes ?? {} + + // Read semantic convention version from environment + const optInValues = Tracer._parseSemconvOptIn() + this._useLatestConventions = optInValues.has('gen_ai_latest_experimental') + this._includeToolDefinitions = optInValues.has('gen_ai_tool_definitions') + + this._isLangfuse = Tracer._detectLangfuse() + + // Get tracer from global API to ensure ground truth + this._tracer = trace.getTracer(getServiceName()) + } + + /** + * All local execution traces collected by this tracer. + */ + get localTraces(): AgentTrace[] { + return this._traceState.traces + } + + /** + * Start an agent invocation span. + * Returns the span which should be ended with endAgentSpan. + * Parents to the current active span from context.active(). + * + * @param options - Options for starting the agent span + */ + startAgentSpan(options: StartAgentSpanOptions): Span | null { + const { messages, agentName, agentId, modelId, tools, traceAttributes, toolsConfig, systemPrompt } = options + + // Reset local trace state for this invocation + this._traceState.traces = [] + this._traceState.currentCycle = undefined + this._traceState.currentModel = undefined + this._traceState.currentTool = undefined + + try { + const spanName = `invoke_agent ${agentName}` + const attributes = this._getCommonAttributes('invoke_agent') + attributes['gen_ai.agent.name'] = agentName + attributes['name'] = spanName + if (agentId) attributes['gen_ai.agent.id'] = agentId + if (modelId) attributes['gen_ai.request.model'] = modelId + + if (tools && tools.length > 0) { + const toolNames = tools.map((t) => t.name) + attributes['gen_ai.agent.tools'] = JSON.stringify(toolNames, jsonReplacer) + } + + if (this._includeToolDefinitions && toolsConfig) { + attributes['gen_ai.tool.definitions'] = JSON.stringify(toolsConfig, jsonReplacer) + } + + if (systemPrompt !== undefined) { + attributes['system_prompt'] = JSON.stringify(systemPrompt, jsonReplacer) + } + + const mergedAttributes = { ...attributes, ...this._traceAttributes, ...traceAttributes } + const span = this._startSpan({ name: spanName, attributes: mergedAttributes, spanKind: SpanKind.INTERNAL }) + + this._addEventMessages(span, messages) + + this._agentSpan = span + + return span + } catch (error) { + logger.warn(`error=<${error}> | failed to start agent span`) + return null + } + } + + /** + * End an agent invocation span. + * + * @param span - The span to end, or null if span creation failed + * @param options - Options for ending the span including response, error, and usage data + */ + endAgentSpan(span: Span | null, options: EndAgentSpanOptions = {}): void { + // Clear stale state from any previous invocation + this._agentSpan = undefined + this._loopSpan = undefined + + // Clear local trace state + this._traceState.currentCycle = undefined + this._traceState.currentModel = undefined + this._traceState.currentTool = undefined + + if (!span) return + + const { response, error, accumulatedUsage, stopReason } = options + + try { + const attributes: Record = {} + if (accumulatedUsage) this._setUsageAttributes(attributes, accumulatedUsage) + // Langfuse auto-generates "generation" observations for spans with token usage, + // which duplicates the token counts already reported on this agent span. + // Setting observation.type to "span" prevents Langfuse from creating that + // extra generation, avoiding double-counted tokens in dashboards. + // See https://github.com/langfuse/langfuse/issues/7549 + if (this._isLangfuse) attributes['langfuse.observation.type'] = 'span' + if (response !== undefined) this._addResponseEvent(span, response, stopReason) + + this._endSpan(span, attributes, error) + } catch (err) { + logger.warn(`error=<${err}> | failed to end agent span`) + } + } + + /** + * Start a model invocation span. + * Parents to the current active span from context.active(). + * + * @param options - Options for starting the model invocation span + */ + startModelInvokeSpan(options: StartModelInvokeSpanOptions): Span | null { + const { messages, modelId, systemPrompt } = options + + // Create local model trace as child of current cycle + this._traceState.currentModel = new AgentTrace( + 'stream_messages', + this._traceState.currentCycle ? { parent: this._traceState.currentCycle } : undefined + ) + + try { + const attributes = this._getCommonAttributes('chat') + if (modelId) attributes['gen_ai.request.model'] = modelId + + const span = this._startSpan({ + name: 'chat', + attributes, + spanKind: SpanKind.INTERNAL, + ...(this._loopSpan && { parentSpan: this._loopSpan }), + }) + this._addSystemPromptEvent(span, systemPrompt) + this._addEventMessages(span, messages) + + return span + } catch (error) { + logger.warn(`error=<${error}> | failed to start model invoke span`) + return null + } + } + + /** + * End a model invocation span. + * + * @param span - The span to end, or null if span creation failed + * @param options - Options for ending the span including usage, metrics, error, and output + */ + endModelInvokeSpan(span: Span | null, options: EndModelSpanOptions = {}): void { + // End local model trace and attach output message + if (this._traceState.currentModel) { + if (options.output) { + this._traceState.currentModel.message = options.output + } + this._traceState.currentModel.end() + this._traceState.currentModel = undefined + } + + if (!span) return + + const { usage, metrics, error, output, stopReason } = options + + try { + if (output !== undefined) this._addOutputEvent(span, output, stopReason) + + const attributes: Record = {} + if (usage) { + this._setUsageAttributes(attributes, usage) + if (metrics) this._setMetricsAttributes(attributes, metrics) + } + + this._endSpan(span, attributes, error) + } catch (err) { + logger.warn(`error=<${err}> | failed to end model invoke span`) + } + } + + /** + * Start a tool call span. + * Parents to the current active span from context.active(). + * + * @param options - Options for starting the tool call span + */ + startToolCallSpan(options: StartToolCallSpanOptions): Span | null { + const { tool } = options + + // Create local tool trace as child of current cycle + const toolTrace = new AgentTrace( + `Tool: ${tool.name}`, + this._traceState.currentCycle ? { parent: this._traceState.currentCycle } : undefined + ) + toolTrace.metadata.toolUseId = tool.toolUseId + toolTrace.metadata.toolName = tool.name + this._traceState.currentTool = toolTrace + + try { + const attributes = this._getCommonAttributes('execute_tool') + attributes['gen_ai.tool.name'] = tool.name + attributes['gen_ai.tool.call.id'] = tool.toolUseId + + const span = this._startSpan({ + name: `execute_tool ${tool.name}`, + attributes, + spanKind: SpanKind.INTERNAL, + ...(this._loopSpan && { parentSpan: this._loopSpan }), + }) + + if (this._useLatestConventions) { + this._addEvent(span, 'gen_ai.client.inference.operation.details', { + 'gen_ai.input.messages': JSON.stringify( + [ + { + role: 'tool', + parts: [{ type: 'tool_call', name: tool.name, id: tool.toolUseId, arguments: tool.input }], + }, + ], + jsonReplacer + ), + }) + } else { + this._addEvent(span, 'gen_ai.tool.message', { + role: 'tool', + content: JSON.stringify(tool.input, jsonReplacer), + id: tool.toolUseId, + }) + } + + return span + } catch (error) { + logger.warn(`error=<${error}> | failed to start tool call span`) + return null + } + } + + /** + * End a tool call span. + * + * @param span - The span to end, or null if span creation failed + * @param options - Options for ending the tool call span + */ + endToolCallSpan(span: Span | null, options: EndToolCallSpanOptions = {}): void { + // End local tool trace + if (this._traceState.currentTool) { + this._traceState.currentTool.end() + this._traceState.currentTool = undefined + } + + if (!span) return + + const { toolResult, error } = options + + try { + const attributes: Record = {} + + if (toolResult) { + const statusStr = typeof toolResult.status === 'string' ? toolResult.status : String(toolResult.status) + attributes['gen_ai.tool.status'] = statusStr + + if (this._useLatestConventions) { + this._addEvent(span, 'gen_ai.client.inference.operation.details', { + 'gen_ai.output.messages': JSON.stringify( + [ + { + role: 'tool', + parts: [{ type: 'tool_call_response', id: toolResult.toolUseId, response: toolResult.content }], + }, + ], + jsonReplacer + ), + }) + } else { + this._addEvent(span, 'gen_ai.choice', { + message: JSON.stringify(toolResult.content, jsonReplacer), + id: toolResult.toolUseId, + }) + } + } + + this._endSpan(span, attributes, error) + } catch (err) { + logger.warn(`error=<${err}> | failed to end tool call span`) + } + } + /** + * Start a multi-agent orchestration span. + * Parents to the current active span from context.active(). + * + * @param options - Options for starting the multi-agent span + * @returns The span, or null if span creation failed + */ + startMultiAgentSpan(options: StartMultiAgentSpanOptions): Span | null { + const { orchestratorId, orchestratorType, input, traceAttributes } = options + + try { + const spanName = `invoke_${orchestratorType} ${orchestratorId}` + const attributes: Record = { + ...this._getCommonAttributes(`invoke_${orchestratorType}`), + 'gen_ai.agent.name': orchestratorType, + 'gen_ai.agent.id': orchestratorId, + name: spanName, + } + if (input) attributes['gen_ai.agent.input'] = JSON.stringify(input, jsonReplacer) + + const mergedAttributes = { ...attributes, ...this._traceAttributes, ...traceAttributes } + const span = this._startSpan({ name: spanName, attributes: mergedAttributes, spanKind: SpanKind.INTERNAL }) + this._multiAgentSpan = span + return span + } catch (error) { + logger.warn(`error=<${error}> | failed to start multi-agent span`) + return null + } + } + + /** + * End a multi-agent orchestration span. + * + * @param span - The span to end, or null if span creation failed + * @param options - Options for ending the span including duration and error + */ + endMultiAgentSpan(span: Span | null, options: EndMultiAgentSpanOptions = {}): void { + this._multiAgentSpan = undefined + + if (!span) return + + try { + const attributes: Record = {} + if (options.duration !== undefined) attributes['gen_ai.agent.execution_time'] = options.duration + if (options.usage) this._setUsageAttributes(attributes, options.usage) + + this._endSpan(span, attributes, options.error) + } catch (err) { + logger.warn(`error=<${err}> | failed to end multi-agent span`) + } + } + + /** + * Start a node execution span. + * Parents to the current active span from context.active(). + * + * @param options - Options for starting the node span + * @returns The span, or null if span creation failed + */ + startNodeSpan(options: StartNodeSpanOptions): Span | null { + const { nodeId, nodeType, traceAttributes } = options + + try { + const spanName = `node ${nodeId}` + const attributes: Record = { + ...this._getCommonAttributes('execute_node'), + 'gen_ai.agent.id': nodeId, + 'gen_ai.agent.node_type': nodeType, + name: spanName, + } + + const mergedAttributes = { ...attributes, ...this._traceAttributes, ...traceAttributes } + return this._startSpan({ + name: spanName, + attributes: mergedAttributes, + spanKind: SpanKind.INTERNAL, + ...(this._multiAgentSpan && { parentSpan: this._multiAgentSpan }), + }) + } catch (error) { + logger.warn(`error=<${error}> | failed to start node span`) + return null + } + } + + /** + * End a node execution span. + * + * @param span - The span to end, or null if span creation failed + * @param options - Options for ending the span including status, duration, and error + */ + endNodeSpan(span: Span | null, options: EndNodeSpanOptions = {}): void { + if (!span) return + + try { + const attributes: Record = {} + if (options.status) attributes['gen_ai.agent.status'] = options.status + if (options.duration !== undefined) attributes['gen_ai.agent.execution_time'] = options.duration + if (options.usage) this._setUsageAttributes(attributes, options.usage) + + this._endSpan(span, attributes, options.error) + } catch (err) { + logger.warn(`error=<${err}> | failed to end node span`) + } + } + + /** + * Runs a callback with the given span set as the active OpenTelemetry context. + * Downstream code (e.g., MCP clients) can read the span from context.active() + * for distributed trace propagation. No-ops if span is null. + * + * @param span - The span to set as active, or null if span creation failed + * @param fn - The callback to run within the span's context + * @returns The return value of the callback + */ + withSpanContext(span: Span | null, fn: () => T): T { + if (!span) return fn() + return context.with(trace.setSpan(context.active(), span), fn) + } + + /** + * Start an agent loop cycle span. + * Parents to the current active span from context.active(). + * + * @param options - Options for starting the agent loop span + */ + startAgentLoopSpan(options: StartAgentLoopSpanOptions): Span | null { + const { cycleId, messages } = options + + // Create local cycle trace + const cycleNumber = this._traceState.traces.length + 1 + this._traceState.currentCycle = new AgentTrace(`Cycle ${cycleNumber}`) + this._traceState.currentCycle.metadata.cycleId = cycleId + this._traceState.traces.push(this._traceState.currentCycle) + + try { + const attributes: Record = { 'agent_loop.cycle_id': cycleId } + const span = this._startSpan({ + name: 'execute_agent_loop_cycle', + attributes, + ...(this._agentSpan && { parentSpan: this._agentSpan }), + }) + this._addEventMessages(span, messages) + this._loopSpan = span + return span + } catch (error) { + logger.warn(`error=<${error}> | failed to start agent loop cycle span`) + return null + } + } + + /** + * End an agent loop cycle span. + * + * @param span - The span to end, or null if span creation failed + * @param options - Options for ending the agent loop span + */ + endAgentLoopSpan(span: Span | null, options: EndAgentLoopSpanOptions = {}): void { + // End local cycle trace + if (this._traceState.currentCycle) { + this._traceState.currentCycle.end() + this._traceState.currentCycle = undefined + } + + if (!span) return + try { + this._endSpan(span, {}, options.error) + this._loopSpan = undefined + } catch (err) { + logger.warn(`error=<${err}> | failed to end agent loop cycle span`) + } + } + + /** + * Create a span parented to the current active context. + */ + private _startSpan(options: { + name: string + attributes?: Record + spanKind?: SpanKind + parentSpan?: Span + }): Span { + const spanOptions: SpanOptions = {} + + if (options.attributes) { + const otelAttributes: Record = {} + for (const [key, value] of Object.entries(options.attributes)) { + if (value !== undefined && value !== null) otelAttributes[key] = value + } + spanOptions.attributes = otelAttributes + } + + if (options.spanKind !== undefined) spanOptions.kind = options.spanKind + + const ctx = options.parentSpan ? trace.setSpan(context.active(), options.parentSpan) : context.active() + const span = this._tracer.startSpan(options.name, spanOptions, ctx) + + try { + span.setAttribute('gen_ai.event.start_time', new Date().toISOString()) + } catch (err) { + logger.warn(`error=<${err}> | failed to set start time attribute`) + } + + return span + } + + /** + * End a span with the given attributes and optional error. + */ + private _endSpan(span: Span, attributes?: Record, error?: Error): void { + try { + const endAttributes: Record = { 'gen_ai.event.end_time': new Date().toISOString() } + if (attributes) Object.assign(endAttributes, attributes) + + span.setAttributes(endAttributes) + + if (error) { + span.setStatus({ code: SpanStatusCode.ERROR, message: error.message }) + span.recordException(error) + } else { + span.setStatus({ code: SpanStatusCode.OK }) + } + + span.end() + } catch (err) { + logger.warn(`error=<${err}> | failed to end span`) + } + } + + /** + * Add an event to a span. + */ + private _addEvent(span: Span, eventName: string, eventAttributes?: Record): void { + try { + if (!eventAttributes) { + span.addEvent(eventName) + return + } + const otelAttributes: Record = {} + for (const [key, value] of Object.entries(eventAttributes)) { + if (value !== undefined && value !== null) otelAttributes[key] = value + } + span.addEvent(eventName, otelAttributes) + } catch (err) { + logger.warn(`error=<${err}>, event=<${eventName}> | failed to add span event`) + } + } + + /** + * Get common attributes based on semantic convention version. + * The attribute name changed between OTEL semconv versions: + * - Stable: 'gen_ai.system' + * - Latest experimental: 'gen_ai.provider.name' + */ + private _getCommonAttributes(operationName: string): Record { + const attributes: Record = { + 'gen_ai.operation.name': operationName, + } + + if (this._useLatestConventions) { + attributes['gen_ai.provider.name'] = getServiceName() + } else { + attributes['gen_ai.system'] = getServiceName() + } + + return attributes + } + + /** + * Add message events to a span. + * Uses different event formats based on semantic convention version: + * - Latest: Single 'gen_ai.client.inference.operation.details' event with all messages + * - Stable: Separate events per message (gen_ai.user.message, gen_ai.assistant.message, etc.) + */ + private _addEventMessages(span: Span, messages: Message[]): void { + try { + if (!Array.isArray(messages)) return + + if (this._useLatestConventions) { + const inputMessages = messages.map((m) => ({ + role: m.role, + parts: Tracer._mapContentBlocksToOtelParts(m.content), + })) + this._addEvent(span, 'gen_ai.client.inference.operation.details', { + 'gen_ai.input.messages': JSON.stringify(inputMessages, jsonReplacer), + }) + } else { + for (const message of messages) { + this._addEvent(span, this._getEventNameForMessage(message), { + content: JSON.stringify(message.content, jsonReplacer), + }) + } + } + } catch (err) { + logger.warn(`error=<${err}> | failed to add message events`) + } + } + + /** + * Get the event name for a message based on its type. + */ + private _getEventNameForMessage(message: Message): string { + if (message.role === 'user' && Array.isArray(message.content)) { + for (const block of message.content) { + if (block && typeof block === 'object' && 'type' in block && block.type === 'toolResultBlock') { + return 'gen_ai.tool.message' + } + } + } + + if (message.role === 'user') return 'gen_ai.user.message' + if (message.role === 'assistant') return 'gen_ai.assistant.message' + return 'gen_ai.message' + } + + /** + * Set usage attributes on an attributes object. + * Sets both legacy (prompt_tokens/completion_tokens) and new (input_tokens/output_tokens) + * attribute names for compatibility with different OTEL backends. + */ + private _setUsageAttributes(attributes: Record, usage: Usage): void { + attributes['gen_ai.usage.prompt_tokens'] = usage.inputTokens + attributes['gen_ai.usage.input_tokens'] = usage.inputTokens + attributes['gen_ai.usage.completion_tokens'] = usage.outputTokens + attributes['gen_ai.usage.output_tokens'] = usage.outputTokens + attributes['gen_ai.usage.total_tokens'] = usage.totalTokens + + if ((usage.cacheReadInputTokens ?? 0) > 0) { + attributes['gen_ai.usage.cache_read_input_tokens'] = usage.cacheReadInputTokens! + } + if ((usage.cacheWriteInputTokens ?? 0) > 0) { + attributes['gen_ai.usage.cache_write_input_tokens'] = usage.cacheWriteInputTokens! + } + } + + /** + * Set metrics attributes on an attributes object. + */ + private _setMetricsAttributes(attributes: Record, metrics: Metrics): void { + if (metrics.latencyMs !== undefined && metrics.latencyMs > 0) { + attributes['gen_ai.server.request.duration'] = metrics.latencyMs + } + } + + /** + * Add response event to a span. + */ + private _addResponseEvent(span: Span, response: Message, stopReason?: string): void { + try { + const finishReason = stopReason || 'end_turn' + + const textParts: string[] = [] + for (const block of response.content) { + if (block.type === 'textBlock') { + textParts.push(block.text) + } + } + const messageText = textParts.join('\n') + + if (this._useLatestConventions) { + this._addEvent(span, 'gen_ai.client.inference.operation.details', { + 'gen_ai.output.messages': JSON.stringify( + [{ role: 'assistant', parts: [{ type: 'text', content: messageText }], finish_reason: finishReason }], + jsonReplacer + ), + }) + } else { + this._addEvent(span, 'gen_ai.choice', { message: messageText, finish_reason: finishReason }) + } + } catch (err) { + logger.warn(`error=<${err}> | failed to add response event`) + } + } + + /** + * Add output event to a span for model invocation. + */ + private _addOutputEvent(span: Span, message: Message, stopReason?: string): void { + try { + const finishReason = stopReason || 'unknown' + + if (this._useLatestConventions) { + this._addEvent(span, 'gen_ai.client.inference.operation.details', { + 'gen_ai.output.messages': JSON.stringify( + [ + { + role: message.role, + parts: Tracer._mapContentBlocksToOtelParts(message.content), + finish_reason: finishReason, + }, + ], + jsonReplacer + ), + }) + } else { + this._addEvent(span, 'gen_ai.choice', { + finish_reason: finishReason, + message: JSON.stringify(Tracer._mapContentBlocksToStableFormat(message.content), jsonReplacer), + }) + } + } catch (err) { + logger.warn(`error=<${err}> | failed to add output event`) + } + } + + /** + * Parse the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. + */ + private static _parseSemconvOptIn(): Set { + const optInEnv = globalThis.process?.env?.OTEL_SEMCONV_STABILITY_OPT_IN ?? '' + return new Set( + optInEnv + .split(',') + .map((value) => value.trim()) + .filter((value) => value.length > 0) + ) + } + + /** + * Detect whether Langfuse is configured as the OTLP endpoint. + * Checks OTEL_EXPORTER_OTLP_ENDPOINT, OTEL_EXPORTER_OTLP_TRACES_ENDPOINT, + * and LANGFUSE_BASE_URL environment variables. + */ + private static _detectLangfuse(): boolean { + const env = globalThis.process?.env + if (!env) return false + + if (env.LANGFUSE_BASE_URL) return true + + const otlpEndpoint = env.OTEL_EXPORTER_OTLP_ENDPOINT ?? '' + const tracesEndpoint = env.OTEL_EXPORTER_OTLP_TRACES_ENDPOINT ?? '' + return otlpEndpoint.includes('langfuse') || tracesEndpoint.includes('langfuse') + } + + /** + * Emit system prompt as a span event per OTel GenAI semantic conventions. + * In stable mode, emits a `gen_ai.system.message` event. + * In latest experimental mode, emits `gen_ai.system_instructions` on the + * `gen_ai.client.inference.operation.details` event. + * + * @param span - The span to add the event to + * @param systemPrompt - The system prompt provided to the model + */ + private _addSystemPromptEvent(span: Span, systemPrompt?: SystemPrompt): void { + if (systemPrompt === undefined) return + + if (this._useLatestConventions) { + const parts = Tracer._mapSystemPromptToOtelParts(systemPrompt) + this._addEvent(span, 'gen_ai.client.inference.operation.details', { + 'gen_ai.system_instructions': JSON.stringify(parts, jsonReplacer), + }) + } else { + // Normalize string prompts to an array of text blocks for consistent format + const blocks = typeof systemPrompt === 'string' ? [{ text: systemPrompt }] : systemPrompt + this._addEvent(span, 'gen_ai.system.message', { + content: JSON.stringify(blocks, jsonReplacer), + }) + } + } + + /** + * Map a system prompt to OTEL parts format (latest conventions). + * Handles both string prompts and SystemContentBlock arrays. + */ + private static _mapSystemPromptToOtelParts(systemPrompt: SystemPrompt): Record[] { + if (typeof systemPrompt === 'string') { + return [{ type: 'text', content: systemPrompt }] + } + return systemPrompt.map((block) => { + switch (block.type) { + case 'textBlock': + return { type: 'text', content: block.text } + case 'cachePointBlock': + return { type: 'cache_point', cacheType: block.cacheType } + case 'guardContentBlock': + return { type: 'guard_content', text: block.text, image: block.image } + } + }) + } + + /** + * Map content blocks to OTEL parts format (latest conventions). + * Converts SDK content block types to OTEL semantic convention format. + */ + private static _mapContentBlocksToOtelParts(contentBlocks: ContentBlock[]): Record[] { + if (!Array.isArray(contentBlocks)) return [] + + return contentBlocks.map((block) => { + switch (block.type) { + case 'textBlock': + return { type: 'text', content: block.text } + case 'toolUseBlock': + return { type: 'tool_call', name: block.name, id: block.toolUseId, arguments: block.input } + case 'toolResultBlock': + return { type: 'tool_call_response', id: block.toolUseId, response: block.content } + default: + return { type: block.type } + } + }) + } + + /** + * Map content blocks to stable format (older conventions). + * Simplifies content blocks to a minimal structure for legacy OTEL backends. + */ + private static _mapContentBlocksToStableFormat(contentBlocks: ContentBlock[]): unknown[] { + if (!Array.isArray(contentBlocks)) return [] + + return contentBlocks + .map((block) => { + switch (block.type) { + case 'textBlock': + return { text: block.text } + case 'toolUseBlock': + return { type: 'toolUse', name: block.name, toolUseId: block.toolUseId, input: block.input } + case 'toolResultBlock': + return { type: 'toolResult', toolUseId: block.toolUseId, content: block.content } + default: + return null + } + }) + .filter(Boolean) + } +} diff --git a/strands-ts/src/telemetry/types.ts b/strands-ts/src/telemetry/types.ts new file mode 100644 index 0000000000..ffa3c4e337 --- /dev/null +++ b/strands-ts/src/telemetry/types.ts @@ -0,0 +1,164 @@ +/** + * Type definitions for OpenTelemetry telemetry support. + */ + +import type { AttributeValue } from '@opentelemetry/api' +import type { Message, SystemPrompt, ToolResultBlock } from '../types/messages.js' +import type { InvokeArgs } from '../types/agent.js' +import type { ToolSpec, ToolUse } from '../tools/types.js' +import type { Usage, Metrics } from '../models/streaming.js' + +// Re-export for convenience +export type { Usage, Metrics } + +/** + * Options for starting an agent span. + */ +export interface StartAgentSpanOptions { + /** Conversation messages to record as span events. */ + messages: Message[] + /** Name of the agent being invoked. */ + agentName: string + /** Unique identifier for the agent instance. */ + agentId?: string + /** Model identifier used by the agent. */ + modelId?: string + /** List of tools available to the agent. */ + tools?: { name: string }[] + /** Custom attributes to merge onto the span. */ + traceAttributes?: Record + /** Tool configuration map, included when gen_ai_tool_definitions opt-in is enabled. */ + toolsConfig?: Record + /** System prompt provided to the agent. */ + systemPrompt?: SystemPrompt +} + +/** + * Options for ending an agent span. + */ +export interface EndAgentSpanOptions { + /** Final response from the agent. */ + response?: Message + /** Error that caused the agent invocation to fail. */ + error?: Error + /** Accumulated token usage across all model calls in this invocation. */ + accumulatedUsage?: Usage + /** Reason the agent stopped (e.g., 'end_turn', 'tool_use'). */ + stopReason?: string +} + +/** + * Options for starting a model invocation span. + */ +export interface StartModelInvokeSpanOptions { + /** Conversation messages sent to the model. */ + messages: Message[] + /** Model identifier being invoked. */ + modelId?: string + /** System prompt provided to the model for this invocation. */ + systemPrompt?: SystemPrompt +} + +/** + * Options for ending a model invocation span. + */ +export interface EndModelSpanOptions { + /** Token usage from this model call. */ + usage?: Usage + /** Performance metrics from this model call. */ + metrics?: Metrics + /** Error that caused the model invocation to fail. */ + error?: Error + /** Message-like object with 'content' and 'role' properties. */ + output?: Message + /** Reason the model stopped generating (e.g., 'end_turn', 'tool_use'). */ + stopReason?: string +} + +/** + * Options for starting a tool call span. + */ +export interface StartToolCallSpanOptions { + /** Tool use request containing name, id, and input arguments. */ + tool: ToolUse +} + +/** + * Options for ending a tool call span. + */ +export interface EndToolCallSpanOptions { + /** Result returned by the tool execution. */ + toolResult?: ToolResultBlock + /** Error that caused the tool call to fail. */ + error?: Error +} + +/** + * Options for starting an agent loop cycle span. + */ +export interface StartAgentLoopSpanOptions { + /** Unique identifier for this loop cycle. */ + cycleId: string + /** Conversation messages at the start of this cycle. */ + messages: Message[] +} + +/** + * Options for ending an agent loop cycle span. + */ +export interface EndAgentLoopSpanOptions { + /** Error that caused the loop cycle to fail. */ + error?: Error +} + +/** + * Options for starting a multi-agent orchestration span. + */ +export interface StartMultiAgentSpanOptions { + /** Unique identifier for the orchestrator instance. */ + orchestratorId: string + /** Orchestration pattern type. */ + orchestratorType: 'graph' | 'swarm' + /** Input task or prompt passed to the orchestrator. */ + input?: InvokeArgs | undefined + /** Custom attributes to merge onto the span. */ + traceAttributes?: Record | undefined +} + +/** + * Options for ending a multi-agent orchestration span. + */ +export interface EndMultiAgentSpanOptions { + /** Error that caused the orchestration to fail. */ + error?: Error | undefined + /** Total duration of the orchestration in milliseconds. */ + duration?: number | undefined + /** Aggregated token usage across all node executions. */ + usage?: Usage | undefined +} + +/** + * Options for starting a node execution span. + */ +export interface StartNodeSpanOptions { + /** Unique identifier for the node. */ + nodeId: string + /** Node type identifier (e.g., 'agentNode', 'multiAgentNode'). */ + nodeType: string + /** Custom attributes to merge onto the span. */ + traceAttributes?: Record | undefined +} + +/** + * Options for ending a node execution span. + */ +export interface EndNodeSpanOptions { + /** Final status of the node execution. */ + status?: string | undefined + /** Duration of the node execution in milliseconds. */ + duration?: number | undefined + /** Token usage from the node execution. */ + usage?: Usage | undefined + /** Error that caused the node execution to fail. */ + error?: Error | undefined +} diff --git a/strands-ts/src/telemetry/utils.ts b/strands-ts/src/telemetry/utils.ts new file mode 100644 index 0000000000..1195291407 --- /dev/null +++ b/strands-ts/src/telemetry/utils.ts @@ -0,0 +1,14 @@ +/** + * Shared telemetry utilities. + */ + +const DEFAULT_SERVICE_NAME = 'strands-agents' + +/** + * Get the service name, respecting the OTEL_SERVICE_NAME environment variable. + * + * @returns The service name from OTEL_SERVICE_NAME or the default 'strands-agents' + */ +export function getServiceName(): string { + return globalThis.process?.env?.OTEL_SERVICE_NAME || DEFAULT_SERVICE_NAME +} diff --git a/strands-ts/src/tools/__tests__/structured-output-tool.test.ts b/strands-ts/src/tools/__tests__/structured-output-tool.test.ts new file mode 100644 index 0000000000..db75508654 --- /dev/null +++ b/strands-ts/src/tools/__tests__/structured-output-tool.test.ts @@ -0,0 +1,105 @@ +import { describe, expect, it, vi } from 'vitest' +import { z } from 'zod' +import { StructuredOutputTool, STRUCTURED_OUTPUT_TOOL_NAME } from '../structured-output-tool.js' +import { JsonBlock, TextBlock, ToolResultBlock } from '../../types/messages.js' +import { createMockContext } from '../../__fixtures__/tool-helpers.js' +import type { JSONValue } from '../../types/json.js' + +/** Helper to run the tool and return the final ToolResultBlock. */ +async function runTool(tool: StructuredOutputTool, input: JSONValue): Promise { + const context = createMockContext({ name: STRUCTURED_OUTPUT_TOOL_NAME, toolUseId: 'tool-1', input }) + const result = await tool.stream(context).next() + return result.value as ToolResultBlock +} + +describe('StructuredOutputTool', () => { + describe('constructor', () => { + it('builds tool spec from schema', () => { + const tool = new StructuredOutputTool(z.object({ name: z.string() }).describe('A person schema')) + + expect(tool.name).toBe(STRUCTURED_OUTPUT_TOOL_NAME) + expect(tool.toolSpec.name).toBe(STRUCTURED_OUTPUT_TOOL_NAME) + expect(tool.toolSpec.inputSchema).toBeDefined() + expect(tool.description).toContain('MUST only be invoked') + expect(tool.description).toContain('A person schema') + }) + + it('uses base description when schema has no description', () => { + const tool = new StructuredOutputTool(z.object({ name: z.string() })) + + expect(tool.description).toContain('MUST only be invoked') + expect(tool.description).not.toContain('') + }) + }) + + describe('stream', () => { + it('returns success with validated JSON for valid input', async () => { + const tool = new StructuredOutputTool(z.object({ name: z.string(), age: z.number() })) + const result = await runTool(tool, { name: 'John', age: 30 }) + + expect(result).toStrictEqual( + new ToolResultBlock({ + toolUseId: 'tool-1', + status: 'success', + content: [new JsonBlock({ json: { name: 'John', age: 30 } })], + }) + ) + }) + + it('returns error with ZodError for invalid input', async () => { + const tool = new StructuredOutputTool(z.object({ name: z.string(), age: z.number() })) + const result = await runTool(tool, { name: 'John', age: 'invalid' }) + + expect(result.status).toBe('error') + expect(result.error).toBeInstanceOf(z.ZodError) + expect((result.content[0] as TextBlock).text).toContain('age') + }) + + it('includes validation details for multiple fields', async () => { + const tool = new StructuredOutputTool(z.object({ name: z.string(), age: z.number(), email: z.string().email() })) + const result = await runTool(tool, { name: 123, age: 'invalid', email: 'not-email' }) + + expect(result.status).toBe('error') + const errorText = (result.content[0] as TextBlock).text + expect(errorText).toContain('name') + expect(errorText).toContain('age') + expect(errorText).toContain('email') + }) + + it('validates nested objects', async () => { + const tool = new StructuredOutputTool(z.object({ user: z.object({ name: z.string(), age: z.number() }) })) + const result = await runTool(tool, { user: { name: 'John', age: 30 } }) + + expect(result.status).toBe('success') + expect((result.content[0] as JsonBlock).json).toEqual({ user: { name: 'John', age: 30 } }) + }) + + it('validates arrays', async () => { + const tool = new StructuredOutputTool(z.object({ items: z.array(z.string()) })) + const result = await runTool(tool, { items: ['a', 'b', 'c'] }) + + expect(result.status).toBe('success') + expect((result.content[0] as JsonBlock).json).toEqual({ items: ['a', 'b', 'c'] }) + }) + + it('handles optional fields', async () => { + const tool = new StructuredOutputTool(z.object({ name: z.string(), age: z.number().optional() })) + const result = await runTool(tool, { name: 'John' }) + + expect(result.status).toBe('success') + expect((result.content[0] as JsonBlock).json).toEqual({ name: 'John' }) + }) + + it('returns error result for non-ZodError exceptions', async () => { + const tool = new StructuredOutputTool(z.object({ value: z.string() })) + vi.spyOn(tool['_schema'], 'parse').mockImplementation(() => { + throw new Error('unexpected parse error') + }) + const result = await runTool(tool, { value: 'valid' }) + + expect(result.status).toBe('error') + expect(result.error).toBeInstanceOf(Error) + expect((result.content[0] as TextBlock).text).toContain('unexpected parse error') + }) + }) +}) diff --git a/strands-ts/src/tools/__tests__/tool-factory.test.ts b/strands-ts/src/tools/__tests__/tool-factory.test.ts new file mode 100644 index 0000000000..d413c3353c --- /dev/null +++ b/strands-ts/src/tools/__tests__/tool-factory.test.ts @@ -0,0 +1,113 @@ +import { describe, expect, it } from 'vitest' +import { z } from 'zod' +import { tool } from '../tool-factory.js' +import { Tool } from '../tool.js' + +describe('tool factory', () => { + describe('dispatch logic', () => { + it('creates ZodTool when inputSchema is a Zod type', () => { + const myTool = tool({ + name: 'zod', + description: 'Zod', + inputSchema: z.object({ x: z.string() }), + callback: (input) => input.x, + }) + + // ZodTool generates JSON schema from Zod with additionalProperties: false + expect(myTool.toolSpec.inputSchema).toHaveProperty('additionalProperties', false) + }) + + it('creates FunctionTool when inputSchema is a plain object', () => { + const schema = { type: 'object' as const, properties: { x: { type: 'string' as const } } } + const myTool = tool({ + name: 'json', + description: 'JSON', + inputSchema: schema, + callback: () => 'ok', + }) + + // JSON schema is passed through as-is + expect(myTool.toolSpec.inputSchema).toStrictEqual(schema) + }) + + it('creates FunctionTool when inputSchema is omitted', () => { + const myTool = tool({ + name: 'noSchema', + description: 'No schema', + callback: () => 'ok', + }) + + expect(myTool.toolSpec.inputSchema).toStrictEqual({ + type: 'object', + properties: {}, + additionalProperties: false, + }) + }) + }) + + describe('FunctionTool invoke()', () => { + it('handles synchronous callback', async () => { + const myTool = tool({ + name: 'sync', + description: 'Sync', + inputSchema: { type: 'object' }, + callback: (input) => { + const { a, b } = input as { a: number; b: number } + return a + b + }, + }) + + expect(await myTool.invoke({ a: 5, b: 3 })).toBe(8) + }) + + it('handles promise callback', async () => { + const myTool = tool({ + name: 'async', + description: 'Async', + inputSchema: { type: 'object' }, + callback: async (input) => `Result: ${(input as { value: string }).value}`, + }) + + expect(await myTool.invoke({ value: 'test' })).toBe('Result: test') + }) + + it('handles async generator callback', async () => { + const myTool = tool({ + name: 'gen', + description: 'Generator', + inputSchema: { type: 'object' }, + callback: async function* (input) { + const { count } = input as { count: number } + for (let i = 1; i <= count; i++) { + yield i + } + return 0 + }, + }) + + expect(await myTool.invoke({ count: 3 })).toBe(0) + }) + + it('passes instanceof Tool check', () => { + const myTool = tool({ + name: 'test', + description: 'test', + inputSchema: { type: 'object' }, + callback: () => 'ok', + }) + + expect(myTool instanceof Tool).toBe(true) + }) + + it('defaults description to empty string', () => { + const myTool = tool({ + name: 'test', + description: '', + inputSchema: { type: 'object' }, + callback: () => 'ok', + }) + + expect(myTool.description).toBe('') + }) + }) +}) diff --git a/strands-ts/src/tools/__tests__/tool.test.ts b/strands-ts/src/tools/__tests__/tool.test.ts new file mode 100644 index 0000000000..c6d7447e77 --- /dev/null +++ b/strands-ts/src/tools/__tests__/tool.test.ts @@ -0,0 +1,1231 @@ +import { describe, expect, it } from 'vitest' +import { FunctionTool } from '../function-tool.js' +import { Tool, ToolStreamEvent, isValidToolName } from '../tool.js' +import type { ToolContext } from '../tool.js' +import type { JSONValue } from '../../types/json.js' +import { createMockContext } from '../../__fixtures__/tool-helpers.js' + +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' + +describe('isValidToolName', () => { + it.each([ + ['simple', true], + ['with_underscore', true], + ['with-hyphen', true], + ['Mixed-Case_123', true], + ['a', true], + ['a'.repeat(64), true], + ])('accepts %s', (name, expected) => { + expect(isValidToolName(name)).toBe(expected) + }) + + it.each([ + ['', 'empty string'], + ['a'.repeat(65), 'over 64 chars'], + ['has space', 'space'], + ['has.dot', 'dot'], + ['has/slash', 'slash'], + ['has:colon', 'colon'], + ['emoji🚀', 'non-ascii'], + ])('rejects %s (%s)', (name) => { + expect(isValidToolName(name)).toBe(false) + }) +}) + +describe('FunctionTool', () => { + describe('properties', () => { + it('has a non-empty toolName', () => { + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema: { type: 'object' }, + callback: (): string => 'result', + }) + expect(tool.name).toBeTruthy() + expect(typeof tool.name).toBe('string') + expect(tool.name.length).toBeGreaterThan(0) + expect(tool.name).toBe('testTool') + }) + + it('has a non-empty description', () => { + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema: { type: 'object' }, + callback: (): string => 'result', + }) + expect(tool.description).toBeTruthy() + expect(typeof tool.description).toBe('string') + expect(tool.description.length).toBeGreaterThan(0) + expect(tool.description).toBe('Test description') + }) + + it('has a valid toolSpec', () => { + const inputSchema = { + type: 'object' as const, + properties: { + value: { type: 'string' as const }, + }, + } + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema, + callback: (): string => 'result', + }) + + // Verify entire toolSpec object at once + expect(tool.toolSpec).toEqual({ + name: 'testTool', + description: 'Test description', + inputSchema, + }) + }) + + it('has matching toolName and toolSpec.name', () => { + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema: { type: 'object' }, + callback: (): string => 'result', + }) + expect(tool.name).toBe(tool.toolSpec.name) + }) + + it('has matching description and toolSpec.description', () => { + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema: { type: 'object' }, + callback: (): string => 'result', + }) + expect(tool.description).toBe(tool.toolSpec.description) + }) + }) + + describe('stream method', () => { + describe('with synchronous callback', () => { + it('wraps return value in ToolResult', async () => { + const tool = new FunctionTool({ + name: 'syncTool', + description: 'Returns synchronous value', + inputSchema: { type: 'object', properties: { value: { type: 'number' } } }, + callback: (input: unknown): number => { + const { value } = input as { value: number } + return value * 2 + }, + }) + + const toolUse = { + name: 'syncTool', + toolUseId: 'test-sync-1', + input: { value: 5 }, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + // No stream events for sync callback + expect(streamEvents.length).toBe(0) + + // Verify entire result with actual calculated value + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-sync-1', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: '10', // 5 * 2 = 10 (converted to string) + }), + ], + }) + }) + + it('handles string return values', async () => { + const tool = new FunctionTool({ + name: 'stringTool', + description: 'Returns string', + inputSchema: { type: 'object' }, + callback: (): string => 'Hello, World!', + }) + + const toolUse = { + name: 'stringTool', + toolUseId: 'test-string', + input: {}, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + expect(streamEvents.length).toBe(0) + + // Verify entire result object + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-string', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'Hello, World!', + }), + ], + }) + }) + + it('handles object return values', async () => { + const tool = new FunctionTool({ + name: 'objectTool', + description: 'Returns object', + inputSchema: { type: 'object' }, + callback: (): { key: string; count: number } => ({ key: 'value', count: 42 }), + }) + + const toolUse = { + name: 'objectTool', + toolUseId: 'test-object', + input: {}, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + expect(streamEvents.length).toBe(0) + + // Verify entire result object + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-object', + status: 'success', + content: [ + expect.objectContaining({ + type: 'jsonBlock', + json: { key: 'value', count: 42 }, + }), + ], + }) + }) + + it('treats objects with extra keys beyond a content block key as JSON', async () => { + const tool = new FunctionTool({ + name: 'extraKeyTool', + description: 'Returns object with text key plus extra keys', + inputSchema: { type: 'object' }, + callback: (): { text: string; extra: string } => ({ text: 'abc', extra: '123' }), + }) + + const toolUse = { name: 'extraKeyTool', toolUseId: 'test-extra', input: {} } + const { result } = await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-extra', + status: 'success', + content: [ + expect.objectContaining({ + type: 'jsonBlock', + json: { text: 'abc', extra: '123' }, + }), + ], + }) + }) + + it('passes input to callback exactly as provided to stream', async () => { + const inputData = { name: 'test', value: 42, nested: { key: 'value' } } + let receivedInput: unknown + + const tool = new FunctionTool({ + name: 'inputTool', + description: 'Captures input', + inputSchema: { type: 'object' }, + callback: (input: unknown): string => { + receivedInput = input + return 'success' + }, + }) + + const toolUse = { + name: 'inputTool', + toolUseId: 'test-input', + input: inputData, + } + + await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(receivedInput).toEqual(inputData) + }) + + it('handles null return values correctly', async () => { + const tool = new FunctionTool({ + name: 'nullTool', + description: 'Returns null', + inputSchema: { type: 'object' }, + callback: (): null => null, + }) + + const { result } = await collectGenerator( + tool.stream(createMockContext({ name: 'nullTool', toolUseId: 'test-null', input: {} })) + ) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-null', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: '', + }), + ], + }) + }) + + it('handles undefined return values correctly', async () => { + const tool = new FunctionTool({ + name: 'undefinedTool', + description: 'Returns undefined', + inputSchema: { type: 'object' }, + // @ts-expect-error we're intentionally testing a type violation + callback: (): undefined => undefined, + }) + + const { result } = await collectGenerator( + tool.stream(createMockContext({ name: 'undefinedTool', toolUseId: 'test-undefined', input: {} })) + ) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-undefined', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: '', + }), + ], + }) + }) + + it('handles boolean return values as text content', async () => { + const trueTool = new FunctionTool({ + name: 'trueTool', + description: 'Returns true', + inputSchema: { type: 'object' }, + callback: (): boolean => true, + }) + + const { result: trueResult } = await collectGenerator( + trueTool.stream(createMockContext({ name: 'trueTool', toolUseId: 'test-true', input: {} })) + ) + + expect(trueResult).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-true', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'true', + }), + ], + }) + + const falseTool = new FunctionTool({ + name: 'falseTool', + description: 'Returns false', + inputSchema: { type: 'object' }, + callback: (): boolean => false, + }) + + const { result: falseResult } = await collectGenerator( + falseTool.stream(createMockContext({ name: 'falseTool', toolUseId: 'test-false', input: {} })) + ) + + expect(falseResult).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-false', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'false', + }), + ], + }) + }) + + it('handles number return values as text content', async () => { + const tool = new FunctionTool({ + name: 'numberTool', + description: 'Returns number', + inputSchema: { type: 'object' }, + callback: (): number => 42, + }) + + const { result } = await collectGenerator( + tool.stream(createMockContext({ name: 'numberTool', toolUseId: 'test-number', input: {} })) + ) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-number', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: '42', + }), + ], + }) + + // Test negative number + const negativeTool = new FunctionTool({ + name: 'negativeTool', + description: 'Returns negative number', + inputSchema: { type: 'object' }, + callback: (): number => -3.14, + }) + + const { result: negativeResult } = await collectGenerator( + negativeTool.stream(createMockContext({ name: 'negativeTool', toolUseId: 'test-negative', input: {} })) + ) + + expect(negativeResult).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-negative', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: '-3.14', + }), + ], + }) + }) + + it('handles array return values as wrapped JSON content', async () => { + const tool = new FunctionTool({ + name: 'arrayTool', + description: 'Returns array', + inputSchema: { type: 'object' }, + callback: (): JSONValue[] => [1, 2, 3, { key: 'value' }], + }) + + const { result } = await collectGenerator( + tool.stream(createMockContext({ name: 'arrayTool', toolUseId: 'test-array', input: {} })) + ) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-array', + status: 'success', + content: [ + expect.objectContaining({ + type: 'jsonBlock', + json: { $value: [1, 2, 3, { key: 'value' }] }, + }), + ], + }) + }) + + it('deep copies objects to prevent mutation', async () => { + const original = { nested: { value: 'original' } } + const tool = new FunctionTool({ + name: 'copyTool', + description: 'Returns object', + inputSchema: { type: 'object' }, + callback: (): { nested: { value: string } } => original, + }) + + const { result } = await collectGenerator( + tool.stream(createMockContext({ name: 'copyTool', toolUseId: 'test-copy', input: {} })) + ) + + // Mutate the original object + original.nested.value = 'mutated' + + // Verify the result still has the original value + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-copy', + status: 'success', + content: [ + expect.objectContaining({ + type: 'jsonBlock', + json: { nested: { value: 'original' } }, + }), + ], + }) + }) + + it('deep copies arrays to prevent mutation', async () => { + const original = [{ value: 'original' }] + const tool = new FunctionTool({ + name: 'arrayCopyTool', + description: 'Returns array', + inputSchema: { type: 'object' }, + callback: (): JSONValue[] => original, + }) + + const { result } = await collectGenerator( + tool.stream(createMockContext({ name: 'arrayCopyTool', toolUseId: 'test-array-copy', input: {} })) + ) + + // Mutate the original array + original[0]!.value = 'mutated' + + // Verify the result still has the original value (wrapped in $value) + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-array-copy', + status: 'success', + content: [ + expect.objectContaining({ + type: 'jsonBlock', + json: { $value: [{ value: 'original' }] }, + }), + ], + }) + }) + }) + + describe('with promise callback', () => { + it('wraps resolved value in ToolResult', async () => { + const tool = new FunctionTool({ + name: 'promiseTool', + description: 'Returns promise', + inputSchema: { type: 'object', properties: { value: { type: 'number' } } }, + callback: async (input: unknown): Promise => { + const { value } = input as { value: number } + return value + 10 + }, + }) + + const toolUse = { + name: 'promiseTool', + toolUseId: 'test-promise-1', + input: { value: 5 }, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + expect(streamEvents.length).toBe(0) + expect(result.toolUseId).toBe('test-promise-1') + expect(result.status).toBe('success') + expect(result.status).toBe('success') + }) + + it('can access ToolContext in promise', async () => { + const tool = new FunctionTool({ + name: 'contextTool', + description: 'Uses context', + inputSchema: { type: 'object' }, + callback: async (_input: unknown, context: ToolContext): Promise => { + return context.agent.appState.getAll() + }, + }) + + const toolUse = { + name: 'contextTool', + toolUseId: 'test-context', + input: {}, + } + const context = createMockContext(toolUse, { userId: 'user-123' }) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + expect(streamEvents.length).toBe(0) + expect(result.status).toBe('success') + }) + }) + + describe('with async generator callback', () => { + it('yields ToolStreamEvents then final ToolResult', async () => { + const tool = new FunctionTool({ + name: 'generatorTool', + description: 'Streams progress', + inputSchema: { type: 'object' }, + callback: async function* (): AsyncGenerator { + yield 'Starting...' + yield 'Processing...' + yield 'Complete!' + return 'Final result' + }, + }) + + const toolUse = { + name: 'generatorTool', + toolUseId: 'test-gen-1', + input: {}, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + // Should have 3 stream events + expect(streamEvents.length).toBe(3) + + // Verify all stream events are as expected + expect(streamEvents).toEqual([ + { type: 'toolStreamEvent', data: 'Starting...' }, + { type: 'toolStreamEvent', data: 'Processing...' }, + { type: 'toolStreamEvent', data: 'Complete!' }, + ]) + + // Verify entire result object + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-gen-1', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'Final result', + }), + ], + }) + }) + + it('can yield objects as ToolStreamEvents', async () => { + const tool = new FunctionTool({ + name: 'objectGenTool', + description: 'Streams objects', + inputSchema: { type: 'object' }, + callback: async function* (): AsyncGenerator<{ progress: number; message: string }, string, unknown> { + yield { progress: 0.25, message: 'Quarter done' } + yield { progress: 0.5, message: 'Halfway done' } + yield { progress: 1.0, message: 'Complete' } + return 'Done' + }, + }) + + const toolUse = { + name: 'objectGenTool', + toolUseId: 'test-obj-gen', + input: {}, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + expect(streamEvents.length).toBe(3) + + // Verify all stream events have data + for (const event of streamEvents) { + expect(event.type).toBe('toolStreamEvent') + expect(event.data).toBeDefined() + } + + // Verify final result + expect(result.status).toBe('success') + }) + }) + + describe('error handling', () => { + it('catches synchronous errors', async () => { + const tool = new FunctionTool({ + name: 'errorTool', + description: 'Throws error', + inputSchema: { type: 'object' }, + callback: (): never => { + throw new Error('Something went wrong') + }, + }) + + const toolUse = { + name: 'errorTool', + toolUseId: 'test-error-1', + input: {}, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + expect(streamEvents.length).toBe(0) + expect(result.toolUseId).toBe('test-error-1') + expect(result.status).toBe('error') + expect(result.content.length).toBeGreaterThan(0) + expect(result.content[0]?.type).toBe('textBlock') + }) + + it('catches promise rejections', async () => { + const tool = new FunctionTool({ + name: 'rejectTool', + description: 'Rejects promise', + inputSchema: { type: 'object' }, + callback: async (): Promise => { + throw new Error('Promise rejected') + }, + }) + + const toolUse = { + name: 'rejectTool', + toolUseId: 'test-error-2', + input: {}, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + expect(streamEvents.length).toBe(0) + expect(result.status).toBe('error') + }) + + it('captures Error object in ToolResult when callback throws Error', async () => { + const testError = new Error('Test error message') + const tool = new FunctionTool({ + name: 'errorObjectTool', + description: 'Throws Error object', + inputSchema: { type: 'object' }, + callback: (): never => { + throw testError + }, + }) + + const toolUse = { + name: 'errorObjectTool', + toolUseId: 'test-error-capture', + input: {}, + } + + const { result } = await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-error-capture', + status: 'error', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'Error: Test error message', + }), + ], + error: testError, + }) + }) + + it('wraps non-Error thrown values into Error object', async () => { + const tool = new FunctionTool({ + name: 'stringThrowTool', + description: 'Throws string', + inputSchema: { type: 'object' }, + callback: (): never => { + throw 'string error' + }, + }) + + const toolUse = { + name: 'stringThrowTool', + toolUseId: 'test-string-wrap', + input: {}, + } + + const { result } = await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-string-wrap', + status: 'error', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'Error: string error', + }), + ], + error: expect.any(Error), + }) + expect(result.error?.message).toBe('string error') + }) + + it('preserves custom Error subclass instances', async () => { + class CustomError extends Error { + constructor( + message: string, + public code: string + ) { + super(message) + this.name = 'CustomError' + } + } + + const customError = new CustomError('Custom error message', 'ERR_001') + const tool = new FunctionTool({ + name: 'customErrorTool', + description: 'Throws custom error', + inputSchema: { type: 'object' }, + callback: (): never => { + throw customError + }, + }) + + const toolUse = { + name: 'customErrorTool', + toolUseId: 'test-custom-error', + input: {}, + } + + const { result } = await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-custom-error', + status: 'error', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'Error: Custom error message', + }), + ], + error: customError, + }) + expect((result.error as CustomError).code).toBe('ERR_001') + }) + + it('preserves error stack traces', async () => { + const tool = new FunctionTool({ + name: 'stackTraceTool', + description: 'Throws error with stack trace', + inputSchema: { type: 'object' }, + callback: (): never => { + throw new Error('Error with stack') + }, + }) + + const toolUse = { + name: 'stackTraceTool', + toolUseId: 'test-stack-trace', + input: {}, + } + + const { result } = await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-stack-trace', + status: 'error', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'Error: Error with stack', + }), + ], + error: expect.any(Error), + }) + expect(result.error?.stack).toBeDefined() + expect(result.error?.stack).toContain('Error: Error with stack') + }) + + it('captures errors thrown in async generator callbacks', async () => { + const testError = new Error('Async generator error') + const tool = new FunctionTool({ + name: 'asyncGenErrorTool', + description: 'Async generator that throws', + inputSchema: { type: 'object' }, + callback: async function* (): AsyncGenerator { + yield 'Starting...' + throw testError + }, + }) + + const toolUse = { + name: 'asyncGenErrorTool', + toolUseId: 'test-async-gen-error', + input: {}, + } + + const context = tool.stream(createMockContext(toolUse)) + const { items: streamEvents, result } = await collectGenerator(context) + + // Should have one stream event before the error + expect(streamEvents.length).toBe(1) + expect(streamEvents[0]).toEqual({ type: 'toolStreamEvent', data: 'Starting...' }) + + // Final result should have error object + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-async-gen-error', + status: 'error', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'Error: Async generator error', + }), + ], + error: testError, + }) + }) + + it('catches errors in async generators', async () => { + const tool = new FunctionTool({ + name: 'genErrorTool', + description: 'Generator throws', + inputSchema: { type: 'object' }, + callback: async function* (): AsyncGenerator { + yield 'Starting...' + throw new Error('Generator error') + }, + }) + + const toolUse = { + name: 'genErrorTool', + toolUseId: 'test-error-3', + input: {}, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + // Should have one stream event before the error + expect(streamEvents.length).toBe(1) + expect(streamEvents[0]?.type).toBe('toolStreamEvent') + + // Final result should be error + expect(result.status).toBe('error') + }) + + it('handles non-Error thrown values', async () => { + const tool = new FunctionTool({ + name: 'stringErrorTool', + description: 'Throws string', + inputSchema: { type: 'object' }, + callback: (): never => { + throw 'String error' + }, + }) + + const toolUse = { + name: 'stringErrorTool', + toolUseId: 'test-error-4', + input: {}, + } + const context = createMockContext(toolUse) + + const { items: streamEvents, result } = await collectGenerator(tool.stream(context)) + + expect(streamEvents.length).toBe(0) + expect(result.status).toBe('error') + }) + + it('returns error for circular references', async () => { + const tool = new FunctionTool({ + name: 'circularTool', + description: 'Returns circular object', + inputSchema: { type: 'object' }, + callback: (): JSONValue => { + // Create circular reference + const obj: any = { a: 1 } + obj.self = obj + return obj + }, + }) + + const { result } = await collectGenerator( + tool.stream(createMockContext({ name: 'circularTool', toolUseId: 'test-circular', input: {} })) + ) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-circular', + status: 'error', + error: expect.any(Error), + content: [ + expect.objectContaining({ + type: 'textBlock', + text: expect.stringContaining('Error:'), + }), + ], + }) + }) + + it('silently drops non-serializable values (functions)', async () => { + const tool = new FunctionTool({ + name: 'functionTool', + description: 'Returns object with function', + inputSchema: { type: 'object' }, + callback: (): JSONValue => { + return { fn: () => {} } as any + }, + }) + + const { result } = await collectGenerator( + tool.stream(createMockContext({ name: 'functionTool', toolUseId: 'test-function', input: {} })) + ) + + // Functions are silently dropped during JSON serialization + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-function', + status: 'success', + content: [ + expect.objectContaining({ + type: 'jsonBlock', + json: {}, + }), + ], + }) + }) + }) + }) +}) + +describe('Tool interface backwards compatibility', () => { + it('maintains stable interface signature', () => { + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema: { type: 'object' }, + callback: (): string => 'result', + }) + + // Verify interface properties exist + expect(tool).toHaveProperty('name') + expect(tool).toHaveProperty('description') + expect(tool).toHaveProperty('toolSpec') + expect(tool).toHaveProperty('stream') + + // Verify stream is a function + expect(typeof tool.stream).toBe('function') + }) + + it('stream method accepts correct parameter types', async () => { + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema: { type: 'object' }, + callback: (input: unknown): JSONValue => input as JSONValue, + }) + const toolUse = { + name: 'testTool', + toolUseId: 'test-types', + input: { value: 123 }, + } + const context = createMockContext(toolUse) + + // This should compile and execute without type errors + const stream = tool.stream({ ...context, toolUse }) + expect(stream).toBeDefined() + expect(Symbol.asyncIterator in stream).toBe(true) + + // Consume the stream with helper + const { result } = await collectGenerator(stream) + + expect(result).toBeDefined() + expect(result.status).toBe('success') + }) +}) + +describe('ToolStreamEvent', () => { + describe('instantiation', () => { + it('creates instance with data', () => { + const event = new ToolStreamEvent({ + data: 'test data', + }) + + expect(event).toEqual({ + type: 'toolStreamEvent', + data: 'test data', + }) + }) + + it('creates instance without data', () => { + const event = new ToolStreamEvent({}) + + expect(event).toEqual({ + type: 'toolStreamEvent', + }) + }) + + it('creates instance with structured data', () => { + const structuredData = { + progress: 50, + message: 'halfway complete', + } + + const event = new ToolStreamEvent({ + data: structuredData, + }) + + expect(event).toEqual({ + type: 'toolStreamEvent', + data: structuredData, + }) + }) + }) +}) + +describe('instanceof checks', () => { + describe('FunctionTool', () => { + it('passes instanceof Tool check', () => { + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema: { type: 'object' }, + callback: (): string => 'result', + }) + + expect(tool instanceof Tool).toBe(true) + }) + + it('can be used as type guard', () => { + const tool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema: { type: 'object' }, + callback: (): string => 'result', + }) + + // Type guard function + function isTool(value: unknown): value is Tool { + return value instanceof Tool + } + + expect(isTool(tool)).toBe(true) + expect(isTool({})).toBe(false) + expect(isTool(null)).toBe(false) + }) + }) +}) + +describe('optional inputSchema', () => { + describe('when inputSchema is undefined', () => { + it('creates tool with default empty object schema', () => { + const tool = new FunctionTool({ + name: 'noInputTool', + description: 'Tool that takes no input', + callback: () => 'result', + }) + + expect(tool.name).toBe('noInputTool') + expect(tool.description).toBe('Tool that takes no input') + expect(tool.toolSpec).toEqual({ + name: 'noInputTool', + description: 'Tool that takes no input', + inputSchema: { + type: 'object', + properties: {}, + additionalProperties: false, + }, + }) + }) + + it('executes successfully with empty input', async () => { + const tool = new FunctionTool({ + name: 'getStatus', + description: 'Gets system status', + callback: () => ({ status: 'operational' }), + }) + + const toolUse = { + name: 'getStatus', + toolUseId: 'test-no-input-1', + input: {}, + } + + const { result } = await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-no-input-1', + status: 'success', + content: [ + expect.objectContaining({ + type: 'jsonBlock', + json: { status: 'operational' }, + }), + ], + }) + }) + + it('callback receives empty object when no schema provided', async () => { + let receivedInput: unknown + const tool = new FunctionTool({ + name: 'captureInput', + description: 'Captures the input', + callback: (input: unknown) => { + receivedInput = input + return 'captured' + }, + }) + + const toolUse = { + name: 'captureInput', + toolUseId: 'test-input-capture', + input: {}, + } + + await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(receivedInput).toEqual({}) + }) + + it('works with async callback', async () => { + const tool = new FunctionTool({ + name: 'asyncNoInput', + description: 'Async tool with no input', + callback: async () => { + return 'async result' + }, + }) + + const toolUse = { + name: 'asyncNoInput', + toolUseId: 'test-async-no-input', + input: {}, + } + + const { result } = await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-async-no-input', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'async result', + }), + ], + }) + }) + + it('works with async generator callback', async () => { + const tool = new FunctionTool({ + name: 'streamNoInput', + description: 'Streaming tool with no input', + callback: async function* () { + yield 'Starting...' + yield 'Processing...' + return 'Complete!' + }, + }) + + const toolUse = { + name: 'streamNoInput', + toolUseId: 'test-stream-no-input', + input: {}, + } + + const { items: streamEvents, result } = await collectGenerator(tool.stream(createMockContext(toolUse))) + + expect(streamEvents).toEqual([ + { type: 'toolStreamEvent', data: 'Starting...' }, + { type: 'toolStreamEvent', data: 'Processing...' }, + ]) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-stream-no-input', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'Complete!', + }), + ], + }) + }) + }) +}) diff --git a/strands-ts/src/tools/__tests__/zod-tool.test-d.ts b/strands-ts/src/tools/__tests__/zod-tool.test-d.ts new file mode 100644 index 0000000000..5f4938bdcf --- /dev/null +++ b/strands-ts/src/tools/__tests__/zod-tool.test-d.ts @@ -0,0 +1,299 @@ +import { describe, it, expectTypeOf } from 'vitest' +import { z } from 'zod' +import { tool } from '../tool-factory.js' + +describe('zod-tool type tests', () => { + describe('invoke return type matches callback return type', () => { + it('should return string when callback returns string', () => { + const stringTool = tool({ + name: 'stringTool', + inputSchema: z.object({ value: z.string() }), + callback: (input) => input.value, + }) + + expectTypeOf(stringTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should return number when callback returns number', () => { + const numberTool = tool({ + name: 'numberTool', + inputSchema: z.object({ a: z.number(), b: z.number() }), + callback: (input) => input.a + input.b, + }) + + expectTypeOf(numberTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should return boolean when callback returns boolean', () => { + const booleanTool = tool({ + name: 'booleanTool', + inputSchema: z.object({ value: z.number() }), + callback: (input) => input.value > 0, + }) + + expectTypeOf(booleanTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should return object when callback returns object', () => { + const objectTool = tool({ + name: 'objectTool', + inputSchema: z.object({ name: z.string(), age: z.number() }), + callback: (input) => ({ greeting: `Hello ${input.name}`, isAdult: input.age >= 18 }), + }) + + expectTypeOf(objectTool.invoke).returns.resolves.toEqualTypeOf<{ + greeting: string + isAdult: boolean + }>() + }) + + it('should return array when callback returns array', () => { + const arrayTool = tool({ + name: 'arrayTool', + inputSchema: z.object({ count: z.number() }), + callback: (input) => Array.from({ length: input.count }, (_, i) => i + 1), + }) + + expectTypeOf(arrayTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should return null when callback returns null', () => { + const nullTool = tool({ + name: 'nullTool', + inputSchema: z.object({ value: z.string() }), + callback: () => null, + }) + + expectTypeOf(nullTool.invoke).returns.resolves.toEqualTypeOf() + }) + }) + + describe('async callback return types', () => { + it('should return string when async callback returns string', () => { + const asyncStringTool = tool({ + name: 'asyncStringTool', + inputSchema: z.object({ value: z.string() }), + callback: async (input): Promise => `Result: ${input.value}`, + }) + + expectTypeOf(asyncStringTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should return number when async callback returns number', () => { + const asyncNumberTool = tool({ + name: 'asyncNumberTool', + inputSchema: z.object({ value: z.number() }), + callback: async (input) => input.value * 2, + }) + + expectTypeOf(asyncNumberTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should return complex object when async callback returns complex object', () => { + const asyncComplexTool = tool({ + name: 'asyncComplexTool', + inputSchema: z.object({ id: z.string() }), + callback: async (input) => ({ + id: input.id, + timestamp: Date.now(), + metadata: { processed: true }, + }), + }) + + expectTypeOf(asyncComplexTool.invoke).returns.resolves.toEqualTypeOf<{ + id: string + timestamp: number + metadata: { processed: true } + }>() + }) + }) + + describe('async generator callback return types', () => { + it('should return the final return value from async generator', () => { + const generatorTool = tool({ + name: 'generatorTool', + inputSchema: z.object({ count: z.number() }), + callback: async function* (input) { + for (let i = 1; i <= input.count; i++) { + yield `Step ${i}` + } + return input.count + }, + }) + + expectTypeOf(generatorTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should return string when async generator returns string', () => { + const generatorStringTool = tool({ + name: 'generatorStringTool', + inputSchema: z.object({ message: z.string() }), + callback: async function* (input): AsyncGenerator { + yield 'Processing...' + yield 'Almost done...' + return `Completed: ${input.message}` + }, + }) + + expectTypeOf(generatorStringTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should return object when async generator returns object', () => { + const generatorObjectTool = tool({ + name: 'generatorObjectTool', + inputSchema: z.object({ data: z.array(z.string()) }), + callback: async function* (input) { + for (const item of input.data) { + yield `Processing ${item}` + } + return { processed: input.data.length, success: true } + }, + }) + + expectTypeOf(generatorObjectTool.invoke).returns.resolves.toEqualTypeOf<{ + processed: number + success: true + }>() + }) + }) + + describe('union return types', () => { + it('should handle union return types correctly', () => { + const unionTool = tool({ + name: 'unionTool', + inputSchema: z.object({ returnType: z.enum(['string', 'number']) }), + callback: (input): string | number => { + if (input.returnType === 'string') { + return 'hello' + } else { + return 42 + } + }, + }) + + expectTypeOf(unionTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should handle conditional return types', () => { + const conditionalTool = tool({ + name: 'conditionalTool', + inputSchema: z.object({ includeMetadata: z.boolean(), value: z.string() }), + callback: (input) => { + if (input.includeMetadata) { + return { value: input.value, metadata: { timestamp: Date.now() } } + } else { + return input.value + } + }, + }) + + expectTypeOf(conditionalTool.invoke).returns.resolves.toEqualTypeOf< + string | { value: string; metadata: { timestamp: number } } + >() + }) + }) + + describe('input type validation', () => { + it('should enforce correct input types', () => { + const typedTool = tool({ + name: 'typedTool', + inputSchema: z.object({ + name: z.string(), + age: z.number(), + active: z.boolean(), + }), + callback: (input) => input.name, + }) + + // Should accept correct input type + expectTypeOf(typedTool.invoke).parameter(0).toEqualTypeOf<{ + name: string + age: number + active: boolean + }>() + }) + + it('should handle optional fields in input', () => { + const optionalTool = tool({ + name: 'optionalTool', + inputSchema: z.object({ + required: z.string(), + optional: z.string().optional(), + }), + callback: (input) => input.required, + }) + + expectTypeOf(optionalTool.invoke).parameter(0).toEqualTypeOf<{ + required: string + optional?: string | undefined + }>() + }) + + it('should handle complex nested input types', () => { + const nestedTool = tool({ + name: 'nestedTool', + inputSchema: z.object({ + user: z.object({ + name: z.string(), + profile: z.object({ + age: z.number(), + preferences: z.array(z.string()), + }), + }), + metadata: z.object({ + created: z.number(), + tags: z.array(z.string()), + }), + }), + callback: (input) => input.user.name, + }) + + expectTypeOf(nestedTool.invoke).parameter(0).toEqualTypeOf<{ + user: { + name: string + profile: { + age: number + preferences: string[] + } + } + metadata: { + created: number + tags: string[] + } + }>() + }) + }) + + describe('generic type constraints', () => { + it('should maintain type safety with explicit generic parameters', () => { + // Test with explicit return type + const explicitTool = tool, string>({ + name: 'explicitTool', + inputSchema: z.object({ value: z.string() }), + callback: (input) => input.value, + }) + + expectTypeOf(explicitTool.invoke).returns.resolves.toEqualTypeOf() + }) + + it('should work with complex generic constraints', () => { + type CustomResult = { + id: string + data: number[] + success: boolean + } + + const customTool = tool, CustomResult>({ + name: 'customTool', + inputSchema: z.object({ id: z.string(), count: z.number() }), + callback: (input): CustomResult => ({ + id: input.id, + data: Array.from({ length: input.count }, (_, i) => i), + success: true, + }), + }) + + expectTypeOf(customTool.invoke).returns.resolves.toEqualTypeOf() + }) + }) +}) diff --git a/strands-ts/src/tools/__tests__/zod-tool.test.ts b/strands-ts/src/tools/__tests__/zod-tool.test.ts new file mode 100644 index 0000000000..c126dba877 --- /dev/null +++ b/strands-ts/src/tools/__tests__/zod-tool.test.ts @@ -0,0 +1,643 @@ +import { describe, expect, it, vi } from 'vitest' +import { z } from 'zod' +import { tool } from '../tool-factory.js' +import { Tool } from '../tool.js' +import { createMockContext } from '../../__fixtures__/tool-helpers.js' +import { collectGenerator } from '../../__fixtures__/model-test-helpers.js' +import type { JSONValue } from '../../types/json.js' +import type { ToolContext } from '../tool.js' + +/** + * Helper to create a mock ToolContext with just input for zod tool tests. + */ +function createContext(input: JSONValue): ToolContext { + return createMockContext({ + name: 'testTool', + toolUseId: 'test-123', + input, + }) +} + +describe('tool', () => { + describe('tool creation and properties', () => { + it('creates tool with correct properties', () => { + const myTool = tool({ + name: 'testTool', + description: 'Test description', + inputSchema: z.object({ value: z.string() }), + callback: (input) => input.value, + }) + + expect(myTool.name).toBe('testTool') + expect(myTool.description).toBe('Test description') + expect(myTool.toolSpec).toEqual({ + name: 'testTool', + description: 'Test description', + inputSchema: { + type: 'object', + properties: { + value: { type: 'string' }, + }, + required: ['value'], + additionalProperties: false, + }, + }) + }) + + it('handles optional description', () => { + const myTool = tool({ + name: 'testTool', + inputSchema: z.object({ value: z.string() }), + callback: (input) => input.value, + }) + + expect(myTool.name).toBe('testTool') + expect(myTool.description).toBe('') + }) + }) + + describe('invoke() method', () => { + describe('basic return types', () => { + it('handles synchronous callback', async () => { + const myTool = tool({ + name: 'sync', + description: 'Synchronous tool', + inputSchema: z.object({ a: z.number(), b: z.number() }), + callback: (input) => input.a + input.b, + }) + + const result = await myTool.invoke({ a: 5, b: 3 }) + expect(result).toBe(8) + }) + + it('handles promise callback', async () => { + const myTool = tool({ + name: 'async', + description: 'Async tool', + inputSchema: z.object({ value: z.string() }), + callback: async (input) => `Result: ${input.value}`, + }) + + const result = await myTool.invoke({ value: 'test' }) + expect(result).toBe('Result: test') + }) + + it('handles async generator callback', async () => { + const myTool = tool({ + name: 'generator', + description: 'Generator tool', + inputSchema: z.object({ count: z.number() }), + callback: async function* (input) { + for (let i = 1; i <= input.count; i++) { + yield i + } + return 0 + }, + }) + + const result = await myTool.invoke({ count: 3 }) + expect(result).toBe(0) + }) + }) + + describe('validation', () => { + it('throws on invalid input', async () => { + const myTool = tool({ + name: 'validator', + description: 'Validates input', + inputSchema: z.object({ age: z.number().min(0).max(120) }), + callback: (input) => input.age, + }) + + await expect(myTool.invoke({ age: -1 })).rejects.toThrow() + await expect(myTool.invoke({ age: 150 })).rejects.toThrow() + }) + + it('validates required fields', async () => { + const myTool = tool({ + name: 'required', + description: 'Required fields', + inputSchema: z.object({ + name: z.string(), + email: z.string().email(), + }), + callback: (input) => `${input.name}: ${input.email}`, + }) + + await expect(myTool.invoke({ name: 'John' } as never)).rejects.toThrow() + await expect(myTool.invoke({ email: 'invalid-email' } as never)).rejects.toThrow() + }) + }) + + describe('context handling', () => { + it('passes context to callback', async () => { + const callback = vi.fn((input, context) => { + expect(context).toBeDefined() + return input.value + }) + + const myTool = tool({ + name: 'context', + description: 'Uses context', + inputSchema: z.object({ value: z.string() }), + callback, + }) + + const mockContext = createContext({ value: 'test' }) + await myTool.invoke({ value: 'test' }, mockContext) + expect(callback).toHaveBeenCalled() + }) + }) + }) + + describe('stream() method', () => { + describe('basic return types', () => { + it('streams synchronous callback result', async () => { + const myTool = tool({ + name: 'sync', + description: 'Synchronous tool', + inputSchema: z.object({ value: z.string() }), + callback: (input) => input.value, + }) + + const context = createContext({ value: 'hello' }) + const { items: events, result } = await collectGenerator(myTool.stream(context)) + + expect(events).toHaveLength(0) // No stream events for sync + expect(result.status).toBe('success') + expect(result.content).toHaveLength(1) + expect(result.content[0]).toEqual(expect.objectContaining({ type: 'textBlock', text: 'hello' })) + }) + + it('streams promise callback result', async () => { + const myTool = tool({ + name: 'async', + description: 'Async tool', + inputSchema: z.object({ value: z.number() }), + callback: async (input) => input.value * 2, + }) + + const context = createContext({ value: 21 }) + const { items: events, result } = await collectGenerator(myTool.stream(context)) + + expect(events).toHaveLength(0) // No stream events for promise + expect(result.status).toBe('success') + expect(result.content).toHaveLength(1) + expect(result.content[0]).toEqual(expect.objectContaining({ type: 'textBlock', text: '42' })) + }) + + it('streams async generator callback results', async () => { + const myTool = tool({ + name: 'generator', + description: 'Generator tool', + inputSchema: z.object({ count: z.number() }), + callback: async function* (input) { + for (let i = 1; i <= input.count; i++) { + yield `Step ${i}` + } + return 0 + }, + }) + + const context = createContext({ count: 3 }) + const { items: events, result } = await collectGenerator(myTool.stream(context)) + + expect(events).toHaveLength(3) + const eventData = events.map((e) => e.data) + expect(eventData).toEqual(['Step 1', 'Step 2', 'Step 3']) + expect(result.status).toBe('success') + }) + }) + + describe('validation', () => { + it('returns error result on validation failure', async () => { + const myTool = tool({ + name: 'validator', + description: 'Validates input', + inputSchema: z.object({ age: z.number().min(0) }), + callback: (input) => input.age, + }) + + const context = createContext({ age: -5 }) + const { items: events, result } = await collectGenerator(myTool.stream(context)) + + expect(events).toHaveLength(0) + expect(result.status).toBe('error') + expect(result.content.length).toBeGreaterThan(0) + const firstContent = result.content[0] + if (firstContent && firstContent.type == 'textBlock') { + expect(firstContent.text).toContain('age') + } + }) + + it('returns error result on missing required fields', async () => { + const myTool = tool({ + name: 'required', + description: 'Required fields', + inputSchema: z.object({ + name: z.string(), + value: z.number(), + }), + callback: (input) => `${input.name}: ${input.value}`, + }) + + const context = createContext({ name: 'test' }) + const { items: events, result } = await collectGenerator(myTool.stream(context)) + + expect(events).toHaveLength(0) + expect(result.status).toBe('error') + }) + }) + + describe('error handling', () => { + it('catches callback errors and returns error result', async () => { + const myTool = tool({ + name: 'error', + description: 'Throws error', + inputSchema: z.object({ value: z.string() }), + callback: () => { + throw new Error('Callback error') + }, + }) + + const context = createContext({ value: 'test' }) + const { items: events, result } = await collectGenerator(myTool.stream(context)) + + expect(events).toHaveLength(0) + expect(result.status).toBe('error') + expect(result.content.length).toBeGreaterThan(0) + const firstContent = result.content[0] + if (firstContent && firstContent.type == 'textBlock') { + expect(firstContent.text).toBe('Error: Callback error') + } + }) + + it('catches async callback errors', async () => { + const myTool = tool({ + name: 'asyncError', + description: 'Throws async error', + inputSchema: z.object({ value: z.string() }), + callback: async () => { + throw new Error('Async error') + }, + }) + + const context = createContext({ value: 'test' }) + const { items: events, result } = await collectGenerator(myTool.stream(context)) + + expect(events).toHaveLength(0) + expect(result.status).toBe('error') + expect(result.content.length).toBeGreaterThan(0) + const firstContent = result.content[0] + if (firstContent && firstContent.type == 'textBlock') { + expect(firstContent.text).toBe('Error: Async error') + } + }) + }) + }) + + describe('complex scenarios', () => { + it('handles nested object schemas', async () => { + const myTool = tool({ + name: 'nested', + description: 'Nested objects', + inputSchema: z.object({ + user: z.object({ + name: z.string(), + age: z.number(), + }), + metadata: z.object({ + timestamp: z.number(), + }), + }), + callback: (input) => `${input.user.name} (${input.user.age})`, + }) + + const result = await myTool.invoke({ + user: { name: 'Alice', age: 30 }, + metadata: { timestamp: Date.now() }, + }) + expect(result).toBe('Alice (30)') + }) + + it('handles enum schemas', async () => { + const myTool = tool({ + name: 'calculator', + description: 'Basic calculator', + inputSchema: z.object({ + operation: z.enum(['add', 'subtract', 'multiply', 'divide']), + a: z.number(), + b: z.number(), + }), + callback: (input) => { + switch (input.operation) { + case 'add': + return input.a + input.b + case 'subtract': + return input.a - input.b + case 'multiply': + return input.a * input.b + case 'divide': + return input.a / input.b + } + }, + }) + + expect(await myTool.invoke({ operation: 'add', a: 5, b: 3 })).toBe(8) + expect(await myTool.invoke({ operation: 'multiply', a: 4, b: 7 })).toBe(28) + }) + + it('handles optional fields', async () => { + const myTool = tool({ + name: 'greeting', + description: 'Generates greeting', + inputSchema: z.object({ + name: z.string(), + title: z.string().optional(), + }), + callback: (input) => { + return input.title ? `${input.title} ${input.name}` : input.name + }, + }) + + expect(await myTool.invoke({ name: 'Smith' })).toBe('Smith') + expect(await myTool.invoke({ name: 'Smith', title: 'Dr.' })).toBe('Dr. Smith') + }) + + it('handles array schemas', async () => { + const myTool = tool({ + name: 'sum', + description: 'Sums numbers', + inputSchema: z.object({ + numbers: z.array(z.number()), + }), + callback: (input) => input.numbers.reduce((a, b) => a + b, 0), + }) + + expect(await myTool.invoke({ numbers: [1, 2, 3, 4, 5] })).toBe(15) + }) + }) + + describe('JSON schema generation', () => { + it('generates valid JSON schema from Zod schema', () => { + const myTool = tool({ + name: 'test', + description: 'Test tool', + inputSchema: z.object({ + name: z.string(), + age: z.number(), + email: z.string().email(), + }), + callback: () => 'result', + }) + + const schema = myTool.toolSpec.inputSchema + expect(schema).toEqual({ + type: 'object', + additionalProperties: false, + properties: { + age: { + type: 'number', + }, + email: { + format: 'email', + pattern: + "^(?!\\.)(?!.*\\.\\.)([A-Za-z0-9_'+\\-\\.]*)[A-Za-z0-9_+-]@([A-Za-z0-9][A-Za-z0-9\\-]*\\.)+[A-Za-z]{2,}$", + type: 'string', + }, + name: { + type: 'string', + }, + }, + required: ['name', 'age', 'email'], + }) + }) + }) + + describe('instanceof checks', () => { + it('passes instanceof Tool check and has InvokableTool methods', () => { + const myTool = tool({ + name: 'testTool', + description: 'Test description', + inputSchema: z.object({ value: z.string() }), + callback: (input) => input.value, + }) + + // Verify instanceof Tool + expect(myTool instanceof Tool).toBe(true) + + // Verify InvokableTool interface methods are present + expect(typeof myTool.invoke).toBe('function') + expect(typeof myTool.stream).toBe('function') + + // Verify can be used as type guard (various types) + expect(myTool instanceof Tool).toBe(true) + expect({} instanceof Tool).toBe(false) + // TypeScript doesn't allow null/undefined in instanceof, verify they're not Tool instances differently + expect((null as unknown) instanceof Tool).toBe(false) + }) + }) + + describe('optional inputSchema', () => { + describe('when inputSchema is undefined', () => { + it('creates tool with default empty object schema', () => { + const myTool = tool({ + name: 'noInputTool', + description: 'Tool with no input', + callback: () => 'result', + }) + + expect(myTool.name).toBe('noInputTool') + expect(myTool.description).toBe('Tool with no input') + expect(myTool.toolSpec).toEqual({ + name: 'noInputTool', + description: 'Tool with no input', + inputSchema: { + type: 'object', + properties: {}, + additionalProperties: false, + }, + }) + }) + + it('invoke() works with empty object', async () => { + const myTool = tool({ + name: 'getPreferences', + description: 'Gets user preferences', + callback: () => ({ theme: 'dark', language: 'en' }), + }) + + const result = await myTool.invoke({}) + expect(result).toEqual({ theme: 'dark', language: 'en' }) + }) + + it('stream() works with empty input', async () => { + const myTool = tool({ + name: 'getStatus', + description: 'Gets system status', + callback: () => ({ status: 'operational', uptime: 99.9 }), + }) + + const { result } = await collectGenerator(myTool.stream(createContext({}))) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-123', + status: 'success', + content: [ + expect.objectContaining({ + type: 'jsonBlock', + json: { status: 'operational', uptime: 99.9 }, + }), + ], + }) + }) + + it('callback receives empty object when no schema', async () => { + let capturedInput: unknown + const myTool = tool({ + name: 'captureInput', + description: 'Captures input', + callback: (input) => { + capturedInput = input + return 'captured' + }, + }) + + await myTool.invoke({}) + expect(capturedInput).toEqual({}) + }) + + it('works with async callback', async () => { + const myTool = tool({ + name: 'asyncNoInput', + description: 'Async tool', + callback: async () => { + return 'async result' + }, + }) + + const result = await myTool.invoke({}) + expect(result).toBe('async result') + }) + + it('works with async generator callback', async () => { + const myTool = tool({ + name: 'streamNoInput', + description: 'Streaming tool', + callback: async function* () { + yield 'Starting...' + yield 'Processing...' + return 'Complete!' + }, + }) + + const result = await myTool.invoke({}) + expect(result).toBe('Complete!') + }) + }) + + describe('when inputSchema is z.void()', () => { + it('creates tool with default empty object schema', () => { + const myTool = tool({ + name: 'voidInputTool', + description: 'Tool with z.void() input', + inputSchema: z.void(), + callback: () => 'result', + }) + + expect(myTool.name).toBe('voidInputTool') + expect(myTool.description).toBe('Tool with z.void() input') + expect(myTool.toolSpec).toEqual({ + name: 'voidInputTool', + description: 'Tool with z.void() input', + inputSchema: { + type: 'object', + properties: {}, + additionalProperties: false, + }, + }) + }) + + it('invoke() works with empty object', async () => { + const myTool = tool({ + name: 'refreshCache', + description: 'Refreshes the cache', + inputSchema: z.void(), + callback: () => ({ refreshed: true, timestamp: Date.now() }), + }) + + const result = await myTool.invoke({} as never) + expect(result).toHaveProperty('refreshed', true) + expect(result).toHaveProperty('timestamp') + }) + + it('stream() works with empty input', async () => { + const myTool = tool({ + name: 'pingServer', + description: 'Pings the server', + inputSchema: z.void(), + callback: () => ({ pong: true }), + }) + + const { result } = await collectGenerator(myTool.stream(createContext({}))) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-123', + status: 'success', + content: [ + expect.objectContaining({ + type: 'jsonBlock', + json: { pong: true }, + }), + ], + }) + }) + + it('works with async generator callback', async () => { + const myTool = tool({ + name: 'streamVoidInput', + description: 'Streaming with void input', + inputSchema: z.void(), + callback: async function* () { + yield 'Step 1' + yield 'Step 2' + return 'Done' + }, + }) + + const { items: streamEvents, result } = await collectGenerator(myTool.stream(createContext({}))) + + expect(streamEvents).toEqual([ + { type: 'toolStreamEvent', data: 'Step 1' }, + { type: 'toolStreamEvent', data: 'Step 2' }, + ]) + + expect(result).toEqual({ + type: 'toolResultBlock', + toolUseId: 'test-123', + status: 'success', + content: [ + expect.objectContaining({ + type: 'textBlock', + text: 'Done', + }), + ], + }) + }) + + it('does not throw Zod conversion errors', () => { + // This test verifies that z.void() doesn't cause errors during tool creation + expect(() => { + tool({ + name: 'voidTool', + description: 'Tool with void', + inputSchema: z.void(), + callback: () => 'ok', + }) + }).not.toThrow() + }) + }) + }) +}) diff --git a/strands-ts/src/tools/function-tool.ts b/strands-ts/src/tools/function-tool.ts new file mode 100644 index 0000000000..ebd2569bf0 --- /dev/null +++ b/strands-ts/src/tools/function-tool.ts @@ -0,0 +1,376 @@ +import { createErrorResult, Tool } from './tool.js' +import type { InvokableTool, ToolContext } from './tool.js' +import { ToolStreamEvent } from './tool.js' +import type { ToolSpec } from './types.js' +import type { JSONSchema, JSONValue } from '../types/json.js' +import { deepCopy } from '../types/json.js' +import { + JsonBlock, + TextBlock, + ToolResultBlock, + toolResultContentFromData, + type ToolResultContent, + type ToolResultContentData, +} from '../types/messages.js' +import { DocumentBlock, ImageBlock, VideoBlock } from '../types/media.js' +import { InterruptError } from '../interrupt.js' + +/** + * Callback function for FunctionTool implementations. + * The callback can return values in multiple ways, and FunctionTool handles the conversion to ToolResultBlock. + * + * @param input - The input parameters conforming to the tool's inputSchema + * @param toolContext - The tool execution context with invocation state + * @returns Can return: + * - AsyncGenerator: Each yielded value becomes a ToolStreamEvent, final value wrapped in ToolResultBlock + * - Promise: Resolved value is wrapped in ToolResultBlock + * - Synchronous value: Value is wrapped in ToolResultBlock + * - If an error is thrown, it's handled and returned as an error ToolResultBlock + * + * @example + * ```typescript + * // Async generator example + * async function* calculator(input: unknown, context: ToolContext) { + * yield 'Calculating...' + * const result = input.a + input.b + * yield `Result: ${result}` + * return result + * } + * + * // Promise example + * async function fetchData(input: unknown, context: ToolContext) { + * const response = await fetch(input.url) + * return await response.json() + * } + * + * // Synchronous example + * function multiply(input: unknown, context: ToolContext) { + * return input.a * input.b + * } + * ``` + */ +export type FunctionToolCallback = ( + input: unknown, + toolContext: ToolContext +) => AsyncGenerator | Promise | JSONValue + +/** + * Configuration options for creating a FunctionTool. + */ +export interface FunctionToolConfig { + /** The unique name of the tool */ + name: string + /** Human-readable description of the tool's purpose */ + description: string + /** JSON Schema defining the expected input structure. If omitted, defaults to an empty object schema. */ + inputSchema?: JSONSchema + /** Function that implements the tool logic */ + callback: FunctionToolCallback +} + +/** + * A Tool implementation that wraps a callback function and handles all ToolResultBlock conversion. + * + * FunctionTool allows creating tools from existing functions without needing to manually + * handle ToolResultBlock formatting or error handling. It supports multiple callback patterns: + * - Async generators for streaming responses + * - Promises for async operations + * - Synchronous functions for immediate results + * + * All return values are automatically wrapped in ToolResultBlock, and errors are caught and + * returned as error ToolResultBlocks. + * + * @example + * ```typescript + * // Create a tool with streaming + * const streamingTool = new FunctionTool({ + * name: 'processor', + * description: 'Processes data with progress updates', + * inputSchema: { type: 'object', properties: { data: { type: 'string' } } }, + * callback: async function* (input: any) { + * yield 'Starting processing...' + * // Do some work + * yield 'Halfway done...' + * // More work + * return 'Processing complete!' + * } + * }) + * ``` + */ +export class FunctionTool extends Tool implements InvokableTool { + /** + * The unique name of the tool. + */ + readonly name: string + + /** + * Human-readable description of what the tool does. + */ + readonly description: string + + /** + * OpenAPI JSON specification for the tool. + */ + readonly toolSpec: ToolSpec + + /** + * The callback function that implements the tool's logic. + */ + private readonly _callback: FunctionToolCallback + + /** + * Creates a new FunctionTool instance. + * + * @param config - Configuration object for the tool + * + * @example + * ```typescript + * // Tool with input schema + * const greetTool = new FunctionTool({ + * name: 'greeter', + * description: 'Greets a person by name', + * inputSchema: { + * type: 'object', + * properties: { name: { type: 'string' } }, + * required: ['name'] + * }, + * callback: (input: any) => `Hello, ${input.name}!` + * }) + * + * // Tool without input (no parameters) + * const statusTool = new FunctionTool({ + * name: 'getStatus', + * description: 'Gets system status', + * callback: () => ({ status: 'operational' }) + * }) + * ``` + */ + constructor(config: FunctionToolConfig) { + super() + this.name = config.name + this.description = config.description + + // Use provided schema or default empty object schema + const inputSchema = config.inputSchema ?? { + type: 'object', + properties: {}, + additionalProperties: false, + } + + this.toolSpec = { + name: config.name, + description: config.description, + inputSchema: inputSchema, + } + this._callback = config.callback + } + + /** + * Executes the tool with streaming support. + * Handles all callback patterns (async generator, promise, sync) and converts results to ToolResultBlock. + * + * @param toolContext - Context information including the tool use request and invocation state + * @returns Async generator that yields ToolStreamEvents and returns a ToolResultBlock + */ + async *stream(toolContext: ToolContext): AsyncGenerator { + const { toolUse } = toolContext + + try { + const result = this._callback(toolUse.input, toolContext) + + // Check if result is an async generator + if (result && typeof result === 'object' && Symbol.asyncIterator in result) { + // Handle async generator: yield each value as ToolStreamEvent, wrap final value in ToolResultBlock + const generator = result as AsyncGenerator + + // Iterate through all yielded values + let iterResult = await generator.next() + + while (!iterResult.done) { + // Each yielded value becomes a ToolStreamEvent + yield new ToolStreamEvent({ + data: iterResult.value, + }) + iterResult = await generator.next() + } + + // The generator's return value (when done = true) is wrapped in ToolResultBlock + return this._wrapInToolResult(iterResult.value, toolUse.toolUseId) + } else if (result instanceof Promise) { + // Handle promise: await and wrap in ToolResultBlock + const value = await result + return this._wrapInToolResult(value, toolUse.toolUseId) + } else { + // Handle synchronous value: wrap in ToolResultBlock + return this._wrapInToolResult(result, toolUse.toolUseId) + } + } catch (error) { + // Re-throw InterruptError to allow interrupt handling in agent loop + if (error instanceof InterruptError) { + throw error + } + // Handle any other errors and yield as error ToolResultBlock + return createErrorResult(error, toolUse.toolUseId) + } + } + + /** + * Invokes the tool directly with raw input and returns the unwrapped result. + * + * Unlike stream(), this method: + * - Returns the raw result (not wrapped in ToolResult) + * - Consumes async generators and returns the generator's return value + * - Lets errors throw naturally (not wrapped in error ToolResult) + * + * @param input - The input parameters for the tool + * @param context - Optional tool execution context + * @returns The unwrapped result + */ + async invoke(input: unknown, context?: ToolContext): Promise { + const result = this._callback(input, context as ToolContext) + + if (result && typeof result === 'object' && Symbol.asyncIterator in result) { + const generator = result as AsyncGenerator + let iterResult = await generator.next() + while (!iterResult.done) { + iterResult = await generator.next() + } + return iterResult.value + } + + return (await result) as JSONValue + } + + /** + * Wraps a value in a ToolResultBlock with success status. + * + * Due to AWS Bedrock limitations (only accepts objects as JSON content), the following + * rules are applied: + * - Content blocks (TextBlock, JsonBlock, ImageBlock, VideoBlock, DocumentBlock) → passed through directly + * - Arrays of content blocks → used directly as content array + * - Strings → TextBlock + * - Numbers, Booleans → TextBlock (converted to string) + * - null, undefined → TextBlock (special string representation) + * - Objects → JsonBlock (with deep copy) + * - Arrays (non-content blocks) → JsonBlock wrapped in \{ $value: array \} (with deep copy) + * + * @param value - The value to wrap (can be any type) + * @param toolUseId - The tool use ID for the ToolResultBlock + * @returns A ToolResultBlock containing the value + */ + private _wrapInToolResult(value: unknown, toolUseId: string): ToolResultBlock { + try { + // Handle media blocks - pass through directly + if (value instanceof DocumentBlock || value instanceof ImageBlock || value instanceof VideoBlock) { + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: [value], + }) + } + + // Handle null with special string representation as text content + if (value === null) { + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: [new TextBlock('')], + }) + } + + // Handle undefined with special string representation as text content + if (value === undefined) { + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: [new TextBlock('')], + }) + } + + // Handle content blocks - class instances or plain data objects + const contentBlock = this._asToolResultContent(value) + if (contentBlock) { + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: [contentBlock], + }) + } + + // Handle primitives (strings, numbers, booleans) as text content + // Bedrock doesn't accept primitives as JSON content, so we convert all to strings + if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') { + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: [new TextBlock(String(value))], + }) + } + + // Handle arrays + if (Array.isArray(value)) { + // Check if array contains only content blocks (class instances or plain data objects) + if (value.length > 0) { + const converted = value.map((item) => this._asToolResultContent(item)) + if (converted.every((item): item is ToolResultContent => item !== undefined)) { + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: converted, + }) + } + } + + // Otherwise wrap in object { $value: array } + const copiedValue = deepCopy(value) + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: [new JsonBlock({ json: { $value: copiedValue } })], + }) + } + + // Handle objects as JSON content with deep copy + const copiedValue = deepCopy(value) + return new ToolResultBlock({ + toolUseId, + status: 'success', + content: [new JsonBlock({ json: copiedValue })], + }) + } catch (error) { + // If deep copy fails (circular references, non-serializable values), return error result + return createErrorResult(error, toolUseId) + } + } + + /** + * Converts a value to a ToolResultContent instance if it matches a known content type. + * Accepts both class instances and plain data objects. + * + * @param value - Value to check and convert + * @returns ToolResultContent instance, or undefined if not a recognized content type + */ + private _asToolResultContent(value: unknown): ToolResultContent | undefined { + if (typeof value !== 'object') return undefined + + // Class instances — pass through + if ( + value instanceof TextBlock || + value instanceof JsonBlock || + value instanceof ImageBlock || + value instanceof VideoBlock || + value instanceof DocumentBlock + ) { + return value + } + + // Plain data objects — require exactly one key to match the discriminated + // union shape; multi-key objects fall through to JsonBlock instead. + try { + if (Object.keys(value as object).length !== 1) return undefined + return toolResultContentFromData(value as ToolResultContentData) + } catch { + return undefined + } + } +} diff --git a/strands-ts/src/tools/mcp-tool.ts b/strands-ts/src/tools/mcp-tool.ts new file mode 100644 index 0000000000..289edf4bb0 --- /dev/null +++ b/strands-ts/src/tools/mcp-tool.ts @@ -0,0 +1,210 @@ +import { McpError, ErrorCode, UrlElicitationRequiredError } from '@modelcontextprotocol/sdk/types.js' + +import { createErrorResult, Tool, type ToolContext, type ToolStreamGenerator } from './tool.js' +import type { ToolSpec } from './types.js' +import type { JSONSchema, JSONValue } from '../types/json.js' +import { JsonBlock, TextBlock, ToolResultBlock, type ToolResultContent } from '../types/messages.js' +import { ImageBlock, decodeBase64 } from '../types/media.js' +import { toMediaFormat, IMAGE_FORMATS, type ImageFormat } from '../mime.js' +import type { McpClient } from '../mcp.js' +import { logger } from '../logging/logger.js' + +export interface McpToolConfig { + name: string + description: string + inputSchema: JSONSchema + client: McpClient +} + +/** + * A Tool implementation that proxies calls to a remote MCP server. + * + * Unlike FunctionTool, which wraps local logic, McpTool delegates execution + * to the connected McpClient and translates the SDK's response format + * directly into ToolResultBlocks. + */ +export class McpTool extends Tool { + readonly name: string + readonly description: string + readonly toolSpec: ToolSpec + private readonly mcpClient: McpClient + + constructor(config: McpToolConfig) { + super() + this.name = config.name + this.description = config.description + this.toolSpec = { + name: config.name, + description: config.description, + inputSchema: config.inputSchema, + } + this.mcpClient = config.client + } + + // eslint-disable-next-line require-yield + async *stream(toolContext: ToolContext): ToolStreamGenerator { + const { toolUseId, input } = toolContext.toolUse + + try { + const rawResult: unknown = await this.mcpClient.callTool(this, input as JSONValue, { + signal: toolContext.agent.cancelSignal, + }) + + if (!this._isMcpToolResult(rawResult)) { + throw new Error('Invalid tool result from MCP Client: missing content array') + } + + const content: ToolResultContent[] = [] + + for (const item of rawResult.content) { + content.push(this._mapMcpContent(item)) + } + + if (content.length === 0) { + content.push(new TextBlock('Tool execution completed successfully with no output.')) + } + + return new ToolResultBlock({ + toolUseId, + status: rawResult.isError ? 'error' : 'success', + content, + }) + } catch (error) { + if ( + error instanceof UrlElicitationRequiredError || + (error instanceof McpError && error.code === ErrorCode.UrlElicitationRequired) + ) { + const elicitations = + error instanceof UrlElicitationRequiredError + ? error.elicitations + : (error.data as Record | undefined)?.elicitations + if (Array.isArray(elicitations) && elicitations.length > 0) { + return new ToolResultBlock({ + toolUseId, + status: 'error', + content: [ + new TextBlock(`MCP Elicitation required: [${String(error)}] with data ${JSON.stringify(elicitations)}`), + ], + }) + } + } + return createErrorResult(error, toolUseId) + } + } + + /** + * Maps a single MCP content item to an SDK ToolResultContent block. + * + * @param item - MCP content item from tool result + * @returns Mapped content block + */ + private _mapMcpContent(item: unknown): ToolResultContent { + if (!item || typeof item !== 'object') { + return new JsonBlock({ json: item as JSONValue }) + } + + const record = item as Record + + switch (record.type) { + case 'text': + if (typeof record.text === 'string') { + return new TextBlock(record.text) + } + return new JsonBlock({ json: item as JSONValue }) + + case 'image': + return this._mapMcpImageContent(record) + + case 'resource': + return this._mapMcpEmbeddedResource(record) + + default: + return new JsonBlock({ json: item as JSONValue }) + } + } + + /** + * Maps an MCP image content item to an ImageBlock. + * + * @param record - MCP image content with data (base64) and mimeType + * @returns ImageBlock or TextBlock fallback if format is unsupported + */ + private _mapMcpImageContent(record: Record): ToolResultContent { + const data = record.data + const mimeType = record.mimeType + + if (typeof data !== 'string' || typeof mimeType !== 'string') { + logger.warn('content_type= | mcp image content missing data or mimeType, falling back to json') + return new JsonBlock({ json: record as JSONValue }) + } + + const format = toMediaFormat(mimeType) + if (!format || !this._isImageFormat(format)) { + logger.warn(`mime_type=<${mimeType}> | unsupported mcp image mime type, falling back to json`) + return new JsonBlock({ json: record as JSONValue }) + } + + return new ImageBlock({ + format, + source: { bytes: decodeBase64(data) }, + }) + } + + /** + * Maps an MCP embedded resource to an SDK content block. + * Text resources become TextBlock, blob resources with image MIME types become ImageBlock. + * + * @param record - MCP embedded resource content + * @returns Mapped content block or undefined if unsupported + */ + private _mapMcpEmbeddedResource(record: Record): ToolResultContent { + const resource = record.resource + if (!resource || typeof resource !== 'object') { + return new JsonBlock({ json: record as JSONValue }) + } + + const res = resource as Record + + // Text resource + if (typeof res.text === 'string') { + return new TextBlock(res.text) + } + + // Blob resource + if (typeof res.blob === 'string' && typeof res.mimeType === 'string') { + const format = toMediaFormat(res.mimeType) + if (format && this._isImageFormat(format)) { + return new ImageBlock({ + format, + source: { bytes: decodeBase64(res.blob) }, + }) + } + // Non-image blob: fall back to json + logger.warn(`mime_type=<${res.mimeType}> | unsupported mcp resource blob mime type, falling back to json`) + } + + return new JsonBlock({ json: record as JSONValue }) + } + + /** + * Type Guard: Checks if value matches the expected MCP SDK result shape. + * \{ content: unknown[]; isError?: boolean \} + */ + private _isMcpToolResult(value: unknown): value is { content: unknown[]; isError?: boolean } { + if (typeof value !== 'object' || value === null) { + return false + } + + // Safe cast to generic record to check properties + const record = value as Record + + return Array.isArray(record.content) + } + + /** + * Type Guard: Checks if a media format is a supported image format. + */ + private _isImageFormat(format: string): format is ImageFormat { + return (IMAGE_FORMATS as readonly string[]).includes(format) + } +} diff --git a/strands-ts/src/tools/noop-tool.ts b/strands-ts/src/tools/noop-tool.ts new file mode 100644 index 0000000000..03a3d7f59b --- /dev/null +++ b/strands-ts/src/tools/noop-tool.ts @@ -0,0 +1,18 @@ +/** + * Shared tool helpers and constants. + */ + +import type { ToolSpec } from './types.js' + +/** + * A no-op tool spec that instructs the model to ignore it completely. + * + * Some model providers (e.g. Bedrock) require a tool configuration when messages + * contain tool use/result blocks. This noop tool can be injected to satisfy that + * requirement without affecting model behavior. + */ +export const NOOP_TOOL_SPEC: ToolSpec = { + name: 'noop', + description: 'This is a fake tool that MUST be completely ignored.', + inputSchema: { type: 'object', properties: {} }, +} diff --git a/strands-ts/src/tools/structured-output-tool.ts b/strands-ts/src/tools/structured-output-tool.ts new file mode 100644 index 0000000000..e5971f2d47 --- /dev/null +++ b/strands-ts/src/tools/structured-output-tool.ts @@ -0,0 +1,95 @@ +import { z } from 'zod' +import { Tool, type ToolContext, type ToolStreamGenerator } from './tool.js' +import type { ToolSpec } from './types.js' +import { JsonBlock, TextBlock, ToolResultBlock } from '../types/messages.js' +import type { JSONValue } from '../types/json.js' +import { zodSchemaToJsonSchema } from './zod-utils.js' + +/** Tool name used for structured output validation. */ +export const STRUCTURED_OUTPUT_TOOL_NAME = 'strands_structured_output' + +/** + * Tool that validates LLM output against a Zod schema. + * Provides validation feedback to the LLM for retry on failures. + */ +export class StructuredOutputTool extends Tool { + private _schema: z.ZodSchema + private _toolSpec: ToolSpec + + /** + * Creates a new StructuredOutputTool. + * + * @param schema - The Zod schema to validate against + */ + constructor(schema: z.ZodSchema) { + super() + this._schema = schema + this._toolSpec = this._buildSpec() + } + + /** @returns The tool name. */ + get name(): string { + return this._toolSpec.name + } + + /** @returns The tool description. */ + get description(): string { + return this._toolSpec.description + } + + /** @returns The full tool specification. */ + get toolSpec(): ToolSpec { + return this._toolSpec + } + + /** + * Validates input against the schema. + * On success, returns a ToolResultBlock with the validated JSON. + * On failure, returns formatted validation errors for LLM retry. + * + * @param toolContext - The tool execution context + * @returns Generator that returns a ToolResultBlock + */ + // Validation is synchronous, so no streaming events are yielded + // eslint-disable-next-line require-yield + async *stream(toolContext: ToolContext): ToolStreamGenerator { + const { toolUse } = toolContext + + try { + const validated = this._schema.parse(toolUse.input) as JSONValue + + return new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status: 'success', + content: [new JsonBlock({ json: validated })], + }) + } catch (error) { + const validationError = error instanceof Error ? error : new Error(String(error)) + + return new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status: 'error', + content: [new TextBlock(validationError.message)], + error: validationError, + }) + } + } + + /** + * Builds the tool specification from the schema. + * + * @returns Tool specification with name, description, and input schema + */ + private _buildSpec(): ToolSpec { + const instruction = + 'This tool MUST only be invoked as the last and final tool before returning the completed result to the caller.' + + return { + name: STRUCTURED_OUTPUT_TOOL_NAME, + description: this._schema.description + ? `${instruction}\n${this._schema.description}` + : instruction, + inputSchema: zodSchemaToJsonSchema(this._schema), + } + } +} diff --git a/strands-ts/src/tools/tool-factory.ts b/strands-ts/src/tools/tool-factory.ts new file mode 100644 index 0000000000..5cd24e28f3 --- /dev/null +++ b/strands-ts/src/tools/tool-factory.ts @@ -0,0 +1,82 @@ +import type { InvokableTool } from './tool.js' +import { FunctionTool } from './function-tool.js' +import type { FunctionToolConfig } from './function-tool.js' +import type { JSONValue } from '../types/json.js' +import { z } from 'zod' +import { ZodTool, type ZodToolConfig } from './zod-tool.js' + +/** + * Checks whether a value is a Zod schema type. + * + * @param value - The value to check + * @returns True if the value is a Zod schema + */ +function isZodType(value: unknown): value is z.ZodType { + return value instanceof z.ZodType +} + +/** + * Creates an InvokableTool from a Zod schema and callback function. + * + * @typeParam TInput - Zod schema type for input validation + * @typeParam TReturn - Return type of the callback function + * @param config - Tool configuration with Zod schema + * @returns An InvokableTool with typed input and output + */ +export function tool( + config: ZodToolConfig +): InvokableTool, TReturn> + +/** + * Creates an InvokableTool from a JSON schema and callback function. + * + * @param config - Tool configuration with optional JSON schema + * @returns An InvokableTool with unknown input + */ +export function tool(config: FunctionToolConfig): InvokableTool + +/** + * Creates an InvokableTool from either a Zod schema or JSON schema configuration. + * + * When a Zod schema is provided as `inputSchema`, input is validated at runtime and + * the callback receives typed input. When a JSON schema (or no schema) is provided, + * the callback receives `unknown` input with no runtime validation. + * + * @example + * ```typescript + * import { tool } from '@strands-agents/sdk' + * import { z } from 'zod' + * + * // With Zod schema (typed + validated) + * const calculator = tool({ + * name: 'calculator', + * description: 'Adds two numbers', + * inputSchema: z.object({ a: z.number(), b: z.number() }), + * callback: (input) => input.a + input.b, + * }) + * + * // With JSON schema (untyped, no validation) + * const greeter = tool({ + * name: 'greeter', + * description: 'Greets a person', + * inputSchema: { + * type: 'object', + * properties: { name: { type: 'string' } }, + * required: ['name'], + * }, + * callback: (input) => `Hello, ${(input as { name: string }).name}!`, + * }) + * ``` + * + * @param config - Tool configuration + * @returns An InvokableTool that implements the Tool interface with invoke() method + */ +export function tool( + config: ZodToolConfig | FunctionToolConfig +): InvokableTool { + if (config.inputSchema && isZodType(config.inputSchema)) { + return new ZodTool(config as ZodToolConfig) + } + + return new FunctionTool(config as FunctionToolConfig) +} diff --git a/strands-ts/src/tools/tool.ts b/strands-ts/src/tools/tool.ts new file mode 100644 index 0000000000..f7959342c3 --- /dev/null +++ b/strands-ts/src/tools/tool.ts @@ -0,0 +1,218 @@ +import type { ToolSpec, ToolUse } from './types.js' +import { TextBlock, ToolResultBlock } from '../types/messages.js' +import type { InvocationState, LocalAgent } from '../types/agent.js' +import { normalizeError } from '../errors.js' +import type { Interruptible } from '../interrupt.js' + +export type { ToolSpec } from './types.js' + +/** + * Context provided to tool implementations during execution. + * Contains framework-level state and information from the agent invocation. + */ +export interface ToolContext extends Interruptible { + /** + * The tool use request that triggered this tool execution. + * Contains the tool name, toolUseId, and input parameters. + */ + toolUse: ToolUse + + /** + * The agent instance that is executing this tool. + * Provides access to agent state, conversation history, and cancellation state. + */ + agent: LocalAgent + + /** + * Per-invocation state shared across hooks and tools for the current + * agent invocation. Mutable — read and write freely; changes are visible to + * subsequent hooks, tools, and on {@link AgentResult.invocationState}. + * + * Distinct from `agent.appState`: `invocationState` is ephemeral and accepts + * arbitrary values, while `appState` is durable, JSON-serializable, and + * deep-copied on read/write. + */ + invocationState: InvocationState +} + +/** + * Data for a tool stream event. + */ +export interface ToolStreamEventData { + /** + * Discriminator for tool stream events. + */ + type: 'toolStreamEvent' + + /** + * Caller-provided data for the progress update. + * Can be any type of data the tool wants to report. + */ + data?: unknown +} + +/** + * Event yielded during tool execution to report streaming progress. + * Tools can yield zero or more of these events before returning the final ToolResult. + * + * @example + * ```typescript + * const streamEvent = new ToolStreamEvent({ + * data: 'Processing step 1...' + * }) + * + * // Or with structured data + * const streamEvent = new ToolStreamEvent({ + * data: { progress: 50, message: 'Halfway complete' } + * }) + * ``` + */ +export class ToolStreamEvent implements ToolStreamEventData { + /** + * Discriminator for tool stream events. + */ + readonly type = 'toolStreamEvent' as const + + /** + * Caller-provided data for the progress update. + * Can be any type of data the tool wants to report. + */ + readonly data?: unknown + + constructor(eventData: { data?: unknown }) { + if (eventData.data !== undefined) { + this.data = eventData.data + } + } +} + +/** + * Type alias for the async generator returned by tool stream methods. + * Yields ToolStreamEvents during execution and returns a ToolResultBlock. + */ +export type ToolStreamGenerator = AsyncGenerator + +/** + * Interface for tool implementations. + * Tools are used by agents to interact with their environment and perform specific actions. + * + * The Tool interface provides a streaming execution model where tools can yield + * progress events during execution before returning a final result. + * + * Most implementations should use FunctionTool rather than implementing this interface directly. + */ +export abstract class Tool { + /** + * The unique name of the tool. + * This MUST match the name in the toolSpec. + */ + abstract name: string + /** + * Human-readable description of what the tool does. + * This helps the model understand when to use the tool. + * + * This MUST match the description in the toolSpec.description. + */ + abstract description: string + /** + * OpenAPI JSON specification for the tool. + * Defines the tool's name, description, and input schema. + */ + abstract toolSpec: ToolSpec + + /** + * Executes the tool with streaming support. + * Yields zero or more ToolStreamEvents during execution, then returns + * exactly one ToolResultBlock as the final value. + * + * @param toolContext - Context information including the tool use request and invocation state + * @returns Async generator that yields ToolStreamEvents and returns a ToolResultBlock + * + * @example + * ```typescript + * const context = { + * toolUse: { + * name: 'calculator', + * toolUseId: 'calc-123', + * input: { operation: 'add', a: 5, b: 3 } + * }, + * } + * + * // The return value is only accessible via explicit .next() calls + * const generator = tool.stream(context) + * for await (const event of generator) { + * // Only yields are captured here + * console.log('Progress:', event.data) + * } + * // Or manually handle the return value: + * let result = await generator.next() + * while (!result.done) { + * console.log('Progress:', result.value.data) + * result = await generator.next() + * } + * console.log('Final result:', result.value.status) + * ``` + */ + abstract stream(toolContext: ToolContext): ToolStreamGenerator +} + +/** + * Extended tool interface that supports direct invocation with type-safe input and output. + * This interface is useful for testing and standalone tool execution. + * + * @typeParam TInput - Type for the tool's input parameters + * @typeParam TReturn - Type for the tool's return value + */ +export interface InvokableTool extends Tool { + /** + * Invokes the tool directly with type-safe input and returns the unwrapped result. + * + * Unlike stream(), this method: + * - Returns the raw result (not wrapped in ToolResult) + * - Consumes async generators and returns only the final value + * - Lets errors throw naturally (not wrapped in error ToolResult) + * + * @param input - The input parameters for the tool + * @param context - Optional tool execution context + * @returns The unwrapped result + */ + invoke(input: TInput, context?: ToolContext): Promise +} + +/** + * Creates an error ToolResultBlock from an error object. + * Ensures all errors are normalized to Error objects and includes the original error + * in the ToolResultBlock for inspection by hooks, error handlers, and agent loop. + * + * TODO: Implement consistent logging format as defined in #30 + * This error should be logged to the caller using the established logging pattern. + * + * @param error - The error that occurred (can be Error object or any thrown value) + * @param toolUseId - The tool use ID for the ToolResultBlock + * @returns A ToolResultBlock with error status, error message content, and original error object + */ +export function createErrorResult(error: unknown, toolUseId: string): ToolResultBlock { + // Ensure error is an Error object (wrap non-Error values) + const errorObject = normalizeError(error) + + return new ToolResultBlock({ + toolUseId, + status: 'error', + content: [new TextBlock(`Error: ${errorObject.message}`)], + error: errorObject, + }) +} + +const TOOL_NAME_PATTERN = /^[a-zA-Z0-9_-]+$/ +const TOOL_NAME_MAX_LENGTH = 64 + +/** + * Returns `true` when `name` satisfies the provider-accepted tool name format: + * non-empty, 1–64 characters, and only letters, digits, underscores, or hyphens. + * + * @param name - The tool name to validate + * @returns `true` if the name is valid, `false` otherwise + */ +export function isValidToolName(name: string): boolean { + return name.length > 0 && name.length <= TOOL_NAME_MAX_LENGTH && TOOL_NAME_PATTERN.test(name) +} diff --git a/strands-ts/src/tools/types.ts b/strands-ts/src/tools/types.ts new file mode 100644 index 0000000000..bcbd1fe4ef --- /dev/null +++ b/strands-ts/src/tools/types.ts @@ -0,0 +1,62 @@ +import type { JSONSchema, JSONValue } from '../types/json.js' + +/** + * Status of a tool execution. + * Indicates whether the tool executed successfully or encountered an error. + */ +export type ToolResultStatus = 'success' | 'error' + +/** + * Specification for a tool that can be used by the model. + * Defines the tool's name, description, and input schema. + */ +export interface ToolSpec { + /** + * The unique name of the tool. + */ + name: string + + /** + * A description of what the tool does. + * This helps the model understand when to use the tool. + */ + description: string + + /** + * JSON Schema defining the expected input structure for the tool. + * If omitted, defaults to an empty object schema allowing no input parameters. + */ + inputSchema?: JSONSchema +} + +/** + * Represents a tool usage request from the model. + * The model generates this when it wants to use a tool. + */ +export interface ToolUse { + /** + * The name of the tool to execute. + */ + name: string + + /** + * Unique identifier for this tool use instance. + * Used to match tool results back to their requests. + */ + toolUseId: string + + /** + * The input parameters for the tool. + * Must be JSON-serializable. + */ + input: JSONValue +} + +/** + * Specifies how the model should choose which tool to use. + * + * - `{ auto: {} }` - Let the model decide whether to use a tool + * - `{ any: {} }` - Force the model to use one of the available tools + * - `{ tool: { name: 'name' } }` - Force the model to use a specific tool + */ +export type ToolChoice = { auto: Record } | { any: Record } | { tool: { name: string } } diff --git a/strands-ts/src/tools/zod-tool.ts b/strands-ts/src/tools/zod-tool.ts new file mode 100644 index 0000000000..aa011d545b --- /dev/null +++ b/strands-ts/src/tools/zod-tool.ts @@ -0,0 +1,178 @@ +import type { InvokableTool, ToolContext, ToolStreamGenerator } from './tool.js' +import { Tool } from './tool.js' +import type { ToolSpec } from './types.js' +import type { JSONSchema, JSONValue } from '../types/json.js' +import { FunctionTool } from './function-tool.js' +import { z, ZodVoid } from 'zod' +import { zodSchemaToJsonSchema } from './zod-utils.js' + +/** + * Helper type to infer input type from Zod schema or default to never. + */ +type ZodInferred = TInput extends z.ZodType ? z.infer : never + +/** + * Configuration for creating a Zod-based tool. + * + * @typeParam TInput - Zod schema type for input validation + * @typeParam TReturn - Return type of the callback function + */ +export interface ZodToolConfig { + /** The name of the tool */ + name: string + + /** A description of what the tool does (optional) */ + description?: string + + /** + * Zod schema for input validation and JSON schema generation. + * If omitted or z.void(), the tool takes no input parameters. + */ + inputSchema?: TInput + + /** + * Callback function that implements the tool's functionality. + * + * @param input - Validated input matching the Zod schema + * @param context - Optional execution context + * @returns The result (can be a value, Promise, or AsyncGenerator) + */ + callback: ( + input: ZodInferred, + context?: ToolContext + ) => AsyncGenerator | Promise | TReturn +} + +/** + * Zod-based tool implementation. + * Extends Tool abstract class and implements InvokableTool interface. + */ +export class ZodTool + extends Tool + implements InvokableTool, TReturn> +{ + /** + * Internal FunctionTool for delegating stream operations. + */ + private readonly _functionTool: FunctionTool + + /** + * Zod schema for input validation. + * Note: undefined is normalized to z.void() in constructor, so this is always defined. + */ + private readonly _inputSchema: z.ZodType + + /** + * User callback function. + */ + private readonly _callback: ( + input: ZodInferred, + context?: ToolContext + ) => AsyncGenerator | Promise | TReturn + + constructor(config: ZodToolConfig) { + super() + const { name, description = '', inputSchema, callback } = config + + // Normalize undefined to z.void() to simplify logic throughout + this._inputSchema = inputSchema ?? z.void() + this._callback = callback + + let generatedSchema: JSONSchema + + // Handle z.void() - use default empty object schema + if (this._inputSchema instanceof ZodVoid) { + generatedSchema = { + type: 'object', + properties: {}, + additionalProperties: false, + } + } else { + generatedSchema = zodSchemaToJsonSchema(this._inputSchema) + } + + // Create a FunctionTool with a validation wrapper + this._functionTool = new FunctionTool({ + name, + description, + inputSchema: generatedSchema, + callback: ( + input: unknown, + toolContext: ToolContext + ): AsyncGenerator | Promise | JSONValue => { + // Only validate if schema is not z.void() (after normalization, it's never undefined) + const validatedInput = this._inputSchema instanceof ZodVoid ? input : this._inputSchema.parse(input) + // Execute user callback with validated input + return callback(validatedInput as ZodInferred, toolContext) as + | AsyncGenerator + | Promise + | JSONValue + }, + }) + } + + /** + * The unique name of the tool. + */ + get name(): string { + return this._functionTool.name + } + + /** + * Human-readable description of what the tool does. + */ + get description(): string { + return this._functionTool.description + } + + /** + * OpenAPI JSON specification for the tool. + */ + get toolSpec(): ToolSpec { + return this._functionTool.toolSpec + } + + /** + * Executes the tool with streaming support. + * Delegates to internal FunctionTool implementation. + * + * @param toolContext - Context information including the tool use request and invocation state + * @returns Async generator that yields ToolStreamEvents and returns a ToolResultBlock + */ + stream(toolContext: ToolContext): ToolStreamGenerator { + return this._functionTool.stream(toolContext) + } + + /** + * Invokes the tool directly with type-safe input and returns the unwrapped result. + * + * Unlike stream(), this method: + * - Returns the raw result (not wrapped in ToolResult) + * - Consumes async generators and returns only the final value + * - Lets errors throw naturally (not wrapped in error ToolResult) + * + * @param input - The input parameters for the tool + * @param context - Optional tool execution context + * @returns The unwrapped result + */ + async invoke(input: ZodInferred, context?: ToolContext): Promise { + // Only validate if schema is not z.void() (after normalization, it's never undefined) + const validatedInput = this._inputSchema instanceof ZodVoid ? input : this._inputSchema.parse(input) + + // Execute callback with validated input + const result = this._callback(validatedInput as ZodInferred, context) + + // Handle different return types + if (result && typeof result === 'object' && Symbol.asyncIterator in result) { + const generator = result as AsyncGenerator + let iterResult = await generator.next() + while (!iterResult.done) { + iterResult = await generator.next() + } + return iterResult.value + } else { + // Regular value or Promise - return directly + return await result + } + } +} diff --git a/strands-ts/src/tools/zod-utils.ts b/strands-ts/src/tools/zod-utils.ts new file mode 100644 index 0000000000..29306f8d66 --- /dev/null +++ b/strands-ts/src/tools/zod-utils.ts @@ -0,0 +1,15 @@ +import { z } from 'zod' +import type { JSONSchema } from '../types/json.js' + +/** + * Converts a Zod schema to JSON Schema format. + * Strips the $schema property to reduce token usage. + * + * @param schema - The Zod schema to convert + * @returns JSON Schema representation + */ +export function zodSchemaToJsonSchema(schema: z.ZodSchema): JSONSchema { + const result = z.toJSONSchema(schema) as JSONSchema & { $schema?: string } + const { $schema: _$schema, ...jsonSchema } = result + return jsonSchema as JSONSchema +} diff --git a/strands-ts/src/tsconfig.json b/strands-ts/src/tsconfig.json new file mode 100644 index 0000000000..4eb37fee05 --- /dev/null +++ b/strands-ts/src/tsconfig.json @@ -0,0 +1,3 @@ +{ + "extends": "../tsconfig.base.json" +} diff --git a/strands-ts/src/types/__tests__/agent.test.ts b/strands-ts/src/types/__tests__/agent.test.ts new file mode 100644 index 0000000000..c580f49138 --- /dev/null +++ b/strands-ts/src/types/__tests__/agent.test.ts @@ -0,0 +1,571 @@ +import { describe, it, expect } from 'vitest' +import { AgentResult } from '../agent.js' +import { AgentMetrics } from '../../telemetry/meter.js' +import { AgentTrace } from '../../telemetry/tracer.js' +import { Message } from '../messages.js' +import { TextBlock, ReasoningBlock, ToolUseBlock, ToolResultBlock, CachePointBlock } from '../messages.js' +import { CitationsBlock } from '../citations.js' +import { Interrupt } from '../../interrupt.js' + +describe('AgentResult', () => { + describe('toString', () => { + describe('when content is empty', () => { + it('returns empty string', () => { + const message = new Message({ + role: 'assistant', + content: [], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.toString()).toBe('') + }) + }) + + describe('when content has single TextBlock', () => { + it('returns the text content', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello, world!')], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.toString()).toBe('Hello, world!') + }) + }) + + describe('when content has multiple TextBlocks', () => { + it('returns all text joined with newlines', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('First line'), new TextBlock('Second line'), new TextBlock('Third line')], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.toString()).toBe('First line\nSecond line\nThird line') + }) + }) + + describe('when content has ReasoningBlock with text', () => { + it('returns the reasoning text with prefix', () => { + const message = new Message({ + role: 'assistant', + content: [new ReasoningBlock({ text: 'Let me think about this...' })], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.toString()).toBe('💭 Reasoning:\n Let me think about this...') + }) + }) + + describe('when content has ReasoningBlock without text', () => { + it('returns empty string (reasoning block is skipped)', () => { + const message = new Message({ + role: 'assistant', + content: [new ReasoningBlock({ signature: 'abc123' })], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.toString()).toBe('') + }) + }) + + describe('when content has mixed TextBlock and ReasoningBlock', () => { + it('returns all text joined with newlines', () => { + const message = new Message({ + role: 'assistant', + content: [ + new TextBlock('Here is my response.'), + new ReasoningBlock({ text: 'I reasoned carefully.' }), + new TextBlock('Additional context.'), + ], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.toString()).toBe( + 'Here is my response.\n💭 Reasoning:\n I reasoned carefully.\nAdditional context.' + ) + }) + }) + + describe('when content has only non-text blocks', () => { + it('returns empty string', () => { + const message = new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ name: 'calc', toolUseId: 'id-1', input: { a: 1, b: 2 } }), + new ToolResultBlock({ + toolUseId: 'id-1', + status: 'success', + content: [new TextBlock('3')], + }), + new CachePointBlock({ cacheType: 'default' }), + ], + }) + + const result = new AgentResult({ + stopReason: 'toolUse', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.toString()).toBe('') + }) + }) + + describe('when content has mixed text and non-text blocks', () => { + it('returns only text from TextBlock and ReasoningBlock', () => { + const message = new Message({ + role: 'assistant', + content: [ + new TextBlock('Before tool'), + new ToolUseBlock({ name: 'calc', toolUseId: 'id-1', input: { a: 1, b: 2 } }), + new ReasoningBlock({ text: 'Thinking...' }), + new CachePointBlock({ cacheType: 'default' }), + new TextBlock('After tool'), + ], + }) + + const result = new AgentResult({ + stopReason: 'toolUse', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.toString()).toBe('Before tool\n💭 Reasoning:\n Thinking...\nAfter tool') + }) + }) + + describe('when interrupts are present', () => { + it('returns JSON-stringified interrupts, taking priority over text content', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('ignored')], + }) + + const interrupt = new Interrupt({ id: 'i-1', name: 'confirm', reason: 'ok?' }) + + const result = new AgentResult({ + stopReason: 'interrupt', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + interrupts: [interrupt], + }) + + expect(result.toString()).toBe(JSON.stringify([interrupt])) + }) + + it('falls through when interrupts array is empty', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + interrupts: [], + }) + + expect(result.toString()).toBe('Hello') + }) + }) + + describe('when structuredOutput is present', () => { + it('returns JSON-stringified structured output, taking priority over text content', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('ignored')], + }) + + const structuredOutput = { answer: 42, note: 'hello' } + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + structuredOutput, + }) + + expect(result.toString()).toBe(JSON.stringify(structuredOutput)) + }) + }) + + describe('when interrupts and structuredOutput are both present', () => { + it('returns interrupts, taking priority over structuredOutput', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('ignored')], + }) + + const interrupt = new Interrupt({ id: 'i-1', name: 'confirm' }) + const structuredOutput = { answer: 42 } + + const result = new AgentResult({ + stopReason: 'interrupt', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + interrupts: [interrupt], + structuredOutput, + }) + + expect(result.toString()).toBe(JSON.stringify([interrupt])) + }) + }) + + describe('when content has CitationsBlock', () => { + it('concatenates generated content text from citations', () => { + const message = new Message({ + role: 'assistant', + content: [ + new TextBlock('Here is a citation:'), + new CitationsBlock({ + citations: [ + { + location: { type: 'documentChar', documentIndex: 0, start: 0, end: 5 }, + source: 'doc', + sourceContent: [{ text: 'source text' }], + title: 'Doc', + }, + ], + content: [{ text: 'cited fragment' }], + }), + ], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.toString()).toBe('Here is a citation:\ncited fragment') + }) + }) + + describe('when called implicitly', () => { + it('works with String() conversion', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(String(result)).toBe('Hello') + }) + + it('works with template literals', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('World')], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(`Response: ${result}`).toBe('Response: World') + }) + }) + }) + + describe('contextSize', () => { + it('returns latestContextSize from metrics', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const metrics = new AgentMetrics({ latestContextSize: 500 }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics, + invocationState: {}, + }) + + expect(result.contextSize).toBe(500) + }) + + it('returns undefined when metrics has no latestContextSize', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.contextSize).toBeUndefined() + }) + + it('returns undefined when no metrics are available', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + invocationState: {}, + }) + + expect(result.contextSize).toBeUndefined() + }) + }) + + describe('projectedContextSize', () => { + it('returns projectedContextSize from metrics', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const metrics = new AgentMetrics({ projectedContextSize: 750 }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics, + invocationState: {}, + }) + + expect(result.projectedContextSize).toBe(750) + }) + + it('returns undefined when metrics has no projectedContextSize', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + metrics: new AgentMetrics(), + invocationState: {}, + }) + + expect(result.projectedContextSize).toBeUndefined() + }) + }) + + describe('toJSON', () => { + it('excludes traces and metrics from serialization', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const traces = [new AgentTrace('Cycle 1')] + const metrics = new AgentMetrics() + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + traces, + metrics, + invocationState: {}, + }) + + const json = result.toJSON() + + expect(json).toEqual({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: message, + }) + }) + + it('includes structuredOutput when present', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Response')], + }) + + const structuredOutput = { field: 'value' } + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + structuredOutput, + invocationState: {}, + }) + + const json = result.toJSON() + + expect(json).toHaveProperty('structuredOutput', structuredOutput) + }) + + it('excludes structuredOutput when not present', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Response')], + }) + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + invocationState: {}, + }) + + const json = result.toJSON() + + expect(json).not.toHaveProperty('structuredOutput') + }) + + it('is automatically used by JSON.stringify()', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const traces = [new AgentTrace('Cycle 1')] + const metrics = new AgentMetrics() + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + traces, + metrics, + invocationState: {}, + }) + + const jsonString = JSON.stringify(result) + const parsed = JSON.parse(jsonString) + + expect(parsed).toEqual({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ + role: 'assistant', + content: expect.any(Array), + }), + }) + }) + + it('preserves traces and metrics as accessible properties', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Hello')], + }) + + const traces = [new AgentTrace('Cycle 1')] + const metrics = new AgentMetrics() + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + traces, + metrics, + invocationState: {}, + }) + + // Properties are still accessible + expect({ traces: result.traces, metrics: result.metrics }).toEqual({ + traces, + metrics, + }) + + // But not in JSON + const json = result.toJSON() + expect(json).toEqual({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: message, + }) + }) + + it('prevents bloated API responses when forwarding result directly', () => { + const message = new Message({ + role: 'assistant', + content: [new TextBlock('Response text')], + }) + + // Simulate large traces and metrics from real agent execution + const traces = [new AgentTrace('Cycle 1'), new AgentTrace('Cycle 2'), new AgentTrace('Cycle 3')] + const metrics = new AgentMetrics() + + const result = new AgentResult({ + stopReason: 'endTurn', + lastMessage: message, + traces, + metrics, + invocationState: {}, + }) + + // Simulate what happens in Express/Next.js: res.json(result) + const apiResponse = JSON.parse(JSON.stringify(result)) + + // Verify API response is lean - no traces/metrics bloat + expect(apiResponse).toEqual({ + type: 'agentResult', + stopReason: 'endTurn', + lastMessage: expect.objectContaining({ + role: 'assistant', + content: expect.any(Array), + }), + }) + expect(apiResponse).not.toHaveProperty('traces') + expect(apiResponse).not.toHaveProperty('metrics') + }) + }) +}) diff --git a/strands-ts/src/types/__tests__/citations.test.ts b/strands-ts/src/types/__tests__/citations.test.ts new file mode 100644 index 0000000000..7aa858d48c --- /dev/null +++ b/strands-ts/src/types/__tests__/citations.test.ts @@ -0,0 +1,115 @@ +import { describe, expect, it } from 'vitest' +import { CitationsBlock, type CitationsBlockData } from '../citations.js' + +describe('CitationsBlock', () => { + const singleCitationData: CitationsBlockData = { + citations: [ + { + location: { type: 'documentChar', documentIndex: 0, start: 10, end: 50 }, + source: 'doc-0', + sourceContent: [{ text: 'source text from document' }], + title: 'Test Document', + }, + ], + content: [{ text: 'generated text with citation' }], + } + + const allVariantsData: CitationsBlockData = { + citations: [ + { + location: { type: 'documentChar', documentIndex: 0, start: 150, end: 300 }, + source: 'doc-0', + sourceContent: [{ text: 'char source' }], + title: 'Text Document', + }, + { + location: { type: 'documentPage', documentIndex: 0, start: 2, end: 3 }, + source: 'doc-0', + sourceContent: [{ text: 'page source' }], + title: 'PDF Document', + }, + { + location: { type: 'documentChunk', documentIndex: 1, start: 5, end: 8 }, + source: 'doc-1', + sourceContent: [{ text: 'chunk source' }], + title: 'Chunked Document', + }, + { + location: { type: 'searchResult', searchResultIndex: 0, start: 25, end: 150 }, + source: 'search-0', + sourceContent: [{ text: 'search source' }], + title: 'Search Result', + }, + { + location: { type: 'web', url: 'https://example.com/doc', domain: 'example.com' }, + source: 'web-0', + sourceContent: [{ text: 'web source' }, { text: 'additional source' }], + title: 'Web Page', + }, + ], + content: [{ text: 'first generated' }, { text: 'second generated' }], + } + + it('creates block with correct type discriminator', () => { + const block = new CitationsBlock(singleCitationData) + expect(block.type).toBe('citationsBlock') + }) + + it('stores citations and content', () => { + const block = new CitationsBlock(singleCitationData) + expect(block.citations).toStrictEqual(singleCitationData.citations) + expect(block.content).toStrictEqual(singleCitationData.content) + }) + + it('round-trips all CitationLocation variants, multiple citations, and multiple content blocks', () => { + const original = new CitationsBlock(allVariantsData) + const json = original.toJSON() + const restored = CitationsBlock.fromJSON(json) + + expect(restored).toEqual(original) + expect(restored.citations).toHaveLength(5) + + expect(restored.citations[0]!.location.type).toBe('documentChar') + expect(restored.citations[1]!.location.type).toBe('documentPage') + expect(restored.citations[2]!.location.type).toBe('documentChunk') + expect(restored.citations[3]!.location.type).toBe('searchResult') + expect(restored.citations[4]!.location.type).toBe('web') + + // Verify web-specific optional domain field survives round-trip + const webLoc = restored.citations[4]!.location + if (webLoc.type === 'web') { + expect(webLoc.domain).toBe('example.com') + } + }) + + it('handles empty arrays', () => { + const data: CitationsBlockData = { + citations: [], + content: [], + } + const block = new CitationsBlock(data) + expect(block.citations).toStrictEqual([]) + expect(block.content).toStrictEqual([]) + + const restored = CitationsBlock.fromJSON(block.toJSON()) + expect(restored).toEqual(block) + }) + + it('toJSON returns wrapped format', () => { + const block = new CitationsBlock(singleCitationData) + const json = block.toJSON() + expect(json).toStrictEqual({ + citations: { + citations: singleCitationData.citations, + content: singleCitationData.content, + }, + }) + }) + + it('works with JSON.stringify', () => { + const original = new CitationsBlock(allVariantsData) + const jsonString = JSON.stringify(original) + const restored = CitationsBlock.fromJSON(JSON.parse(jsonString)) + expect(restored).toEqual(original) + }) +}) diff --git a/strands-ts/src/types/__tests__/json.test.ts b/strands-ts/src/types/__tests__/json.test.ts new file mode 100644 index 0000000000..c55d116f4c --- /dev/null +++ b/strands-ts/src/types/__tests__/json.test.ts @@ -0,0 +1,383 @@ +import { describe, it, expect } from 'vitest' +import { deepCopy, deepCopyWithValidation } from '../json.js' +import { JsonValidationError } from '../../errors.js' + +describe('deepCopy', () => { + describe('primitive values', () => { + it('copies strings', () => { + const result = deepCopy('hello') + expect(result).toBe('hello') + }) + + it('copies numbers', () => { + const result = deepCopy(42) + expect(result).toBe(42) + }) + + it('copies booleans', () => { + const result = deepCopy(true) + expect(result).toBe(true) + }) + + it('copies null', () => { + const result = deepCopy(null) + expect(result).toBe(null) + }) + }) + + describe('object values', () => { + it('creates a deep copy of objects', () => { + const original = { nested: { value: 'test' } } + const copy = deepCopy(original) + + expect(copy).toEqual(original) + expect(copy).not.toBe(original) // Different reference + + // Verify deep copy - modifying original shouldn't affect copy + original.nested.value = 'changed' + expect((copy as { nested: { value: string } }).nested.value).toBe('test') + }) + + it('copies empty objects', () => { + const result = deepCopy({}) + expect(result).toEqual({}) + }) + + it('copies objects with multiple properties', () => { + const original = { a: 1, b: 'two', c: true, d: null } + const copy = deepCopy(original) + expect(copy).toEqual(original) + }) + }) + + describe('array values', () => { + it('creates a deep copy of arrays', () => { + const original = [1, 2, 3, { nested: 'value' }] + const copy = deepCopy(original) + + expect(copy).toEqual(original) + expect(copy).not.toBe(original) // Different reference + + // Verify deep copy - modifying original shouldn't affect copy + original[0] = 999 + expect((copy as number[])[0]).toBe(1) + }) + + it('copies empty arrays', () => { + const result = deepCopy([]) + expect(result).toEqual([]) + }) + + it('copies nested arrays', () => { + const original = [ + [1, 2], + [3, 4], + ] + const copy = deepCopy(original) + expect(copy).toEqual(original) + }) + }) + + describe('error handling', () => { + it('throws error for circular references', () => { + const circular: { self?: unknown } = {} + circular.self = circular + + expect(() => deepCopy(circular)).toThrow('Unable to serialize tool result') + }) + + it('silently drops functions from objects', () => { + const withFunction = { + normalProp: 'value', + funcProp: (): string => 'test', + } + + const result = deepCopy(withFunction) + expect(result).toEqual({ normalProp: 'value' }) + expect(result).not.toHaveProperty('funcProp') + }) + + it('silently drops symbols from objects', () => { + const sym = Symbol('test') + const withSymbol = { + normalProp: 'value', + [sym]: 'symbolValue', + } + + const result = deepCopy(withSymbol) + expect(result).toEqual({ normalProp: 'value' }) + // Symbols are dropped during JSON serialization + expect(Object.getOwnPropertySymbols(result as object)).toHaveLength(0) + }) + + it('silently drops undefined values from objects', () => { + const withUndefined = { + normalProp: 'value', + undefinedProp: undefined, + } + + const result = deepCopy(withUndefined) + expect(result).toEqual({ normalProp: 'value' }) + expect(result).not.toHaveProperty('undefinedProp') + }) + }) + + describe('complex nested structures', () => { + it('copies deeply nested structures', () => { + const original = { + level1: { + level2: { + level3: { + array: [1, 2, { deep: 'value' }], + string: 'test', + }, + }, + }, + } + + const copy = deepCopy(original) + expect(copy).toEqual(original) + expect(copy).not.toBe(original) + }) + + it('copies arrays of objects', () => { + const original = [ + { id: 1, name: 'first' }, + { id: 2, name: 'second' }, + { id: 3, name: 'third' }, + ] + + const copy = deepCopy(original) + expect(copy).toEqual(original) + expect(copy).not.toBe(original) + }) + }) +}) + +describe('deepCopyWithValidation', () => { + describe('primitive values', () => { + it('copies strings', () => { + const result = deepCopyWithValidation('hello', 'testValue') + expect(result).toBe('hello') + }) + + it('copies numbers', () => { + const result = deepCopyWithValidation(42, 'testValue') + expect(result).toBe(42) + }) + + it('copies booleans', () => { + const result = deepCopyWithValidation(true, 'testValue') + expect(result).toBe(true) + }) + + it('copies null', () => { + const result = deepCopyWithValidation(null, 'testValue') + expect(result).toBe(null) + }) + }) + + describe('object values', () => { + it('creates a deep copy of objects', () => { + const original = { nested: { value: 'test' } } + const copy = deepCopyWithValidation(original, 'testValue') + + expect(copy).toEqual(original) + expect(copy).not.toBe(original) // Different reference + + // Verify deep copy - modifying original shouldn't affect copy + original.nested.value = 'changed' + expect((copy as { nested: { value: string } }).nested.value).toBe('test') + }) + + it('copies empty objects', () => { + const result = deepCopyWithValidation({}, 'testValue') + expect(result).toEqual({}) + }) + + it('copies objects with multiple properties', () => { + const original = { a: 1, b: 'two', c: true, d: null } + const copy = deepCopyWithValidation(original, 'testValue') + expect(copy).toEqual(original) + }) + }) + + describe('array values', () => { + it('creates a deep copy of arrays', () => { + const original = [1, 2, 3, { nested: 'value' }] + const copy = deepCopyWithValidation(original, 'testValue') + + expect(copy).toEqual(original) + expect(copy).not.toBe(original) // Different reference + + // Verify deep copy - modifying original shouldn't affect copy + original[0] = 999 + expect((copy as number[])[0]).toBe(1) + }) + + it('copies empty arrays', () => { + const result = deepCopyWithValidation([], 'testValue') + expect(result).toEqual([]) + }) + + it('copies nested arrays', () => { + const original = [ + [1, 2], + [3, 4], + ] + const copy = deepCopyWithValidation(original, 'testValue') + expect(copy).toEqual(original) + }) + }) + + describe('validation errors', () => { + it('throws JsonValidationError for functions at top level', () => { + const func = (): string => 'test' + + expect(() => deepCopyWithValidation(func, 'testValue')).toThrow(JsonValidationError) + expect(() => deepCopyWithValidation(func, 'testValue')).toThrow( + 'testValue contains a function which cannot be serialized' + ) + }) + + it('throws JsonValidationError for functions in objects', () => { + const withFunction = { + normalProp: 'value', + funcProp: (): string => 'test', + } + + expect(() => deepCopyWithValidation(withFunction, 'testValue')).toThrow(JsonValidationError) + expect(() => deepCopyWithValidation(withFunction, 'testValue')).toThrow( + 'testValue.funcProp contains a function which cannot be serialized' + ) + }) + + it('throws JsonValidationError for functions in nested objects', () => { + const nested = { + level1: { + level2: { + func: (): string => 'test', + }, + }, + } + + expect(() => deepCopyWithValidation(nested, 'config')).toThrow(JsonValidationError) + expect(() => deepCopyWithValidation(nested, 'config')).toThrow( + 'config.level1.level2.func contains a function which cannot be serialized' + ) + }) + + it('throws JsonValidationError for functions in arrays', () => { + const withFunction = [1, 2, (): string => 'test'] + + expect(() => deepCopyWithValidation(withFunction, 'items')).toThrow(JsonValidationError) + expect(() => deepCopyWithValidation(withFunction, 'items')).toThrow( + 'items[2] contains a function which cannot be serialized' + ) + }) + + it('throws JsonValidationError for symbols in objects', () => { + const sym = Symbol('test') + const withSymbol = { + normalProp: 'value', + symProp: sym, + } + + expect(() => deepCopyWithValidation(withSymbol, 'testValue')).toThrow(JsonValidationError) + expect(() => deepCopyWithValidation(withSymbol, 'testValue')).toThrow( + 'testValue.symProp contains a symbol which cannot be serialized' + ) + }) + + it('throws JsonValidationError for symbols in arrays', () => { + const sym = Symbol('test') + const withSymbol = [1, 2, sym] + + expect(() => deepCopyWithValidation(withSymbol, 'items')).toThrow(JsonValidationError) + expect(() => deepCopyWithValidation(withSymbol, 'items')).toThrow( + 'items[2] contains a symbol which cannot be serialized' + ) + }) + + it('throws JsonValidationError for undefined values in objects', () => { + const withUndefined = { + normalProp: 'value', + undefinedProp: undefined, + } + + expect(() => deepCopyWithValidation(withUndefined, 'testValue')).toThrow(JsonValidationError) + expect(() => deepCopyWithValidation(withUndefined, 'testValue')).toThrow( + 'testValue.undefinedProp is undefined which cannot be serialized' + ) + }) + + it('throws JsonValidationError for undefined values in arrays', () => { + const withUndefined = [1, 2, undefined] + + expect(() => deepCopyWithValidation(withUndefined, 'items')).toThrow(JsonValidationError) + expect(() => deepCopyWithValidation(withUndefined, 'items')).toThrow( + 'items[2] is undefined which cannot be serialized' + ) + }) + + it('throws JsonValidationError for circular references', () => { + const circular: { self?: unknown } = {} + circular.self = circular + + expect(() => deepCopyWithValidation(circular, 'testValue')).toThrow('circular structure') + }) + }) + + describe('complex nested structures', () => { + it('copies deeply nested structures', () => { + const original = { + level1: { + level2: { + level3: { + array: [1, 2, { deep: 'value' }], + string: 'test', + }, + }, + }, + } + + const copy = deepCopyWithValidation(original, 'testValue') + expect(copy).toEqual(original) + expect(copy).not.toBe(original) + }) + + it('copies arrays of objects', () => { + const original = [ + { id: 1, name: 'first' }, + { id: 2, name: 'second' }, + { id: 3, name: 'third' }, + ] + + const copy = deepCopyWithValidation(original, 'testValue') + expect(copy).toEqual(original) + expect(copy).not.toBe(original) + }) + }) + + describe('context path parameter', () => { + it('uses custom context path in error messages', () => { + const withFunction = { + func: (): string => 'test', + } + + expect(() => deepCopyWithValidation(withFunction, 'initialState')).toThrow( + 'initialState.func contains a function which cannot be serialized' + ) + }) + + it('uses default context path when not provided', () => { + const withFunction = { + func: (): string => 'test', + } + + expect(() => deepCopyWithValidation(withFunction)).toThrow( + 'value.func contains a function which cannot be serialized' + ) + }) + }) +}) diff --git a/strands-ts/src/types/__tests__/media.test.ts b/strands-ts/src/types/__tests__/media.test.ts new file mode 100644 index 0000000000..9248e11247 --- /dev/null +++ b/strands-ts/src/types/__tests__/media.test.ts @@ -0,0 +1,530 @@ +import { describe, it, expect } from 'vitest' +import { + S3Location, + ImageBlock, + VideoBlock, + DocumentBlock, + encodeBase64, + decodeBase64, + type ImageBlockData, + type VideoBlockData, + type DocumentBlockData, +} from '../media.js' +import { TextBlock } from '../messages.js' + +describe('S3Location', () => { + it('creates instance with uri only', () => { + const location = new S3Location({ + uri: 's3://my-bucket/image.jpg', + }) + expect(location).toEqual({ + type: 's3', + uri: 's3://my-bucket/image.jpg', + }) + }) + + it('creates instance with uri and bucketOwner', () => { + const location = new S3Location({ + uri: 's3://my-bucket/image.jpg', + bucketOwner: '123456789012', + }) + expect(location).toEqual({ + type: 's3', + uri: 's3://my-bucket/image.jpg', + bucketOwner: '123456789012', + }) + }) +}) + +describe('ImageBlock', () => { + it('creates instance with bytes source', () => { + const bytes = new Uint8Array([1, 2, 3]) + const block = new ImageBlock({ + format: 'jpeg', + source: { bytes }, + }) + expect(block).toEqual({ + type: 'imageBlock', + format: 'jpeg', + source: { type: 'imageSourceBytes', bytes }, + }) + }) + + it('creates instance with S3 location source', () => { + const block = new ImageBlock({ + format: 'png', + source: { + location: { + type: 's3', + uri: 's3://my-bucket/image.png', + bucketOwner: '123456789012', + }, + }, + }) + expect(block).toEqual({ + type: 'imageBlock', + format: 'png', + source: { + type: 'imageSourceS3Location', + location: expect.any(S3Location), + }, + }) + // Assert S3Location was converted to class + const s3Source = block.source as { type: 'imageSourceS3Location'; location: S3Location } + expect(s3Source.location).toBeInstanceOf(S3Location) + expect(s3Source.location.uri).toBe('s3://my-bucket/image.png') + expect(s3Source.location.bucketOwner).toBe('123456789012') + }) + + it('creates instance with URL source', () => { + const block = new ImageBlock({ + format: 'webp', + source: { url: 'https://example.com/image.webp' }, + }) + expect(block).toEqual({ + type: 'imageBlock', + format: 'webp', + source: { type: 'imageSourceUrl', url: 'https://example.com/image.webp' }, + }) + }) + + it('throws error for invalid source', () => { + const data = { + format: 'jpeg', + source: {}, + } as ImageBlockData + expect(() => new ImageBlock(data)).toThrow('Invalid image source') + }) +}) + +describe('VideoBlock', () => { + it('creates instance with bytes source', () => { + const bytes = new Uint8Array([1, 2, 3]) + const block = new VideoBlock({ + format: 'mp4', + source: { bytes }, + }) + expect(block).toEqual({ + type: 'videoBlock', + format: 'mp4', + source: { type: 'videoSourceBytes', bytes }, + }) + }) + + it('creates instance with S3 location source', () => { + const block = new VideoBlock({ + format: 'webm', + source: { + location: { + type: 's3', + uri: 's3://my-bucket/video.webm', + }, + }, + }) + expect(block).toEqual({ + type: 'videoBlock', + format: 'webm', + source: { + type: 'videoSourceS3Location', + location: expect.any(S3Location), + }, + }) + // Assert S3Location was converted to class + const s3Source = block.source as { type: 'videoSourceS3Location'; location: S3Location } + expect(s3Source.location).toBeInstanceOf(S3Location) + expect(s3Source.location.uri).toBe('s3://my-bucket/video.webm') + }) + + it('throws error for invalid source', () => { + const data = { + format: 'mp4', + source: {}, + } as VideoBlockData + expect(() => new VideoBlock(data)).toThrow('Invalid video source') + }) +}) + +describe('DocumentBlock', () => { + it('creates instance with bytes source', () => { + const bytes = new Uint8Array([1, 2, 3]) + const block = new DocumentBlock({ + name: 'document.pdf', + format: 'pdf', + source: { bytes }, + }) + expect(block).toEqual({ + type: 'documentBlock', + name: 'document.pdf', + format: 'pdf', + source: { type: 'documentSourceBytes', bytes }, + }) + }) + + it('creates instance with text source', () => { + const block = new DocumentBlock({ + name: 'note.txt', + format: 'txt', + source: { text: 'Hello world' }, + }) + expect(block).toEqual({ + type: 'documentBlock', + format: 'txt', + name: 'note.txt', + source: { type: 'documentSourceText', text: 'Hello world' }, + }) + }) + + it('creates instance with content source', () => { + const block = new DocumentBlock({ + name: 'report.html', + format: 'html', + source: { + content: [{ text: 'Introduction' }, { text: 'Conclusion' }], + }, + }) + expect(block).toEqual({ + type: 'documentBlock', + name: 'report.html', + format: 'html', + source: { + type: 'documentSourceContentBlock', + content: [expect.any(TextBlock), expect.any(TextBlock)], + }, + }) + // Assert content blocks were converted to TextBlock instances + const contentSource = block.source as { type: 'documentSourceContentBlock'; content: TextBlock[] } + expect(contentSource.content).toHaveLength(2) + expect(contentSource.content[0]).toBeInstanceOf(TextBlock) + expect(contentSource.content[0]!.text).toBe('Introduction') + expect(contentSource.content[1]).toBeInstanceOf(TextBlock) + expect(contentSource.content[1]!.text).toBe('Conclusion') + }) + + it('creates instance with S3 location source', () => { + const block = new DocumentBlock({ + name: 'report.pdf', + format: 'pdf', + source: { + location: { + type: 's3', + uri: 's3://my-bucket/report.pdf', + bucketOwner: '123456789012', + }, + }, + }) + expect(block).toEqual({ + type: 'documentBlock', + name: 'report.pdf', + format: 'pdf', + source: { + type: 'documentSourceS3Location', + location: { + type: 's3', + uri: 's3://my-bucket/report.pdf', + bucketOwner: '123456789012', + }, + }, + }) + }) + + it('creates instance with bytes and filename', () => { + const bytes = new Uint8Array([1, 2, 3]) + const block = new DocumentBlock({ + name: 'upload.pdf', + format: 'pdf', + source: { bytes }, + }) + expect(block).toEqual({ + type: 'documentBlock', + name: 'upload.pdf', + format: 'pdf', + source: { type: 'documentSourceBytes', bytes }, + }) + }) + + it('creates instance with text and filename', () => { + const block = new DocumentBlock({ + name: 'note.txt', + format: 'txt', + source: { text: 'Hello world' }, + }) + expect(block).toEqual({ + type: 'documentBlock', + format: 'txt', + name: 'note.txt', + source: { type: 'documentSourceText', text: 'Hello world' }, + }) + }) + + it('creates instance with citations and context', () => { + const bytes = new Uint8Array([1, 2, 3]) + const block = new DocumentBlock({ + name: 'research.pdf', + format: 'pdf', + source: { bytes }, + citations: { enabled: true }, + context: 'Research paper about AI', + }) + expect(block).toEqual({ + type: 'documentBlock', + name: 'research.pdf', + format: 'pdf', + source: { + type: 'documentSourceBytes', + bytes, + }, + citations: { enabled: true }, + context: 'Research paper about AI', + }) + }) + + it('throws error for invalid source', () => { + const data = { + name: 'doc.pdf', + format: 'pdf', + source: {}, + } as DocumentBlockData + expect(() => new DocumentBlock(data)).toThrow('Invalid document source') + }) +}) + +describe('encodeBase64 and decodeBase64', () => { + it('round-trips empty array', () => { + const original = new Uint8Array([]) + const encoded = encodeBase64(original) + const decoded = decodeBase64(encoded) + expect(decoded).toEqual(original) + }) + + it('round-trips single byte', () => { + const original = new Uint8Array([42]) + const encoded = encodeBase64(original) + const decoded = decodeBase64(encoded) + expect(decoded).toEqual(original) + }) + + it('round-trips multi-byte array', () => { + const original = new Uint8Array([1, 2, 3, 255, 0, 128]) + const encoded = encodeBase64(original) + const decoded = decodeBase64(encoded) + expect(decoded).toEqual(original) + }) + + it('round-trips large array', () => { + const original = new Uint8Array(1000) + for (let i = 0; i < original.length; i++) { + original[i] = i % 256 + } + const encoded = encodeBase64(original) + const decoded = decodeBase64(encoded) + expect(decoded).toEqual(original) + }) +}) + +describe('fromJSON with serialized (base64 string) input', () => { + it('ImageBlock.fromJSON accepts base64 string for bytes', () => { + const originalBytes = new Uint8Array([1, 2, 3, 4, 5]) + const base64String = encodeBase64(originalBytes) + const block = ImageBlock.fromJSON({ + image: { format: 'jpeg', source: { bytes: base64String } }, + }) + expect((block.source as { type: 'imageSourceBytes'; bytes: Uint8Array }).bytes).toEqual(originalBytes) + }) + + it('ImageBlock.fromJSON accepts Uint8Array for bytes', () => { + const originalBytes = new Uint8Array([1, 2, 3, 4, 5]) + const block = ImageBlock.fromJSON({ + image: { format: 'jpeg', source: { bytes: originalBytes } }, + }) + expect((block.source as { type: 'imageSourceBytes'; bytes: Uint8Array }).bytes).toEqual(originalBytes) + }) + + it('VideoBlock.fromJSON accepts base64 string for bytes', () => { + const originalBytes = new Uint8Array([10, 20, 30]) + const base64String = encodeBase64(originalBytes) + const block = VideoBlock.fromJSON({ + video: { format: 'mp4', source: { bytes: base64String } }, + }) + expect((block.source as { type: 'videoSourceBytes'; bytes: Uint8Array }).bytes).toEqual(originalBytes) + }) + + it('VideoBlock.fromJSON accepts Uint8Array for bytes', () => { + const originalBytes = new Uint8Array([10, 20, 30]) + const block = VideoBlock.fromJSON({ + video: { format: 'mp4', source: { bytes: originalBytes } }, + }) + expect((block.source as { type: 'videoSourceBytes'; bytes: Uint8Array }).bytes).toEqual(originalBytes) + }) + + it('DocumentBlock.fromJSON accepts base64 string for bytes', () => { + const originalBytes = new Uint8Array([100, 200]) + const base64String = encodeBase64(originalBytes) + const block = DocumentBlock.fromJSON({ + document: { name: 'doc.pdf', format: 'pdf', source: { bytes: base64String } }, + }) + expect((block.source as { type: 'documentSourceBytes'; bytes: Uint8Array }).bytes).toEqual(originalBytes) + }) + + it('DocumentBlock.fromJSON accepts Uint8Array for bytes', () => { + const originalBytes = new Uint8Array([100, 200]) + const block = DocumentBlock.fromJSON({ + document: { name: 'doc.pdf', format: 'pdf', source: { bytes: originalBytes } }, + }) + expect((block.source as { type: 'documentSourceBytes'; bytes: Uint8Array }).bytes).toEqual(originalBytes) + }) +}) + +describe('S3Location toJSON/fromJSON', () => { + it('round-trips with uri only', () => { + const original = new S3Location({ uri: 's3://bucket/key.jpg' }) + const json = original.toJSON() + const restored = S3Location.fromJSON(json) + expect(restored).toEqual(original) + }) + + it('round-trips with uri and bucketOwner', () => { + const original = new S3Location({ uri: 's3://bucket/key.jpg', bucketOwner: '123456789012' }) + const json = original.toJSON() + const restored = S3Location.fromJSON(json) + expect(restored).toEqual(original) + }) + + it('includes type in JSON output', () => { + const location = new S3Location({ uri: 's3://bucket/key.jpg' }) + const json = location.toJSON() + expect(json).toStrictEqual({ type: 's3', uri: 's3://bucket/key.jpg' }) + expect('bucketOwner' in json).toBe(false) + }) +}) + +describe('ImageBlock toJSON/fromJSON', () => { + it('round-trips with bytes source', () => { + const original = new ImageBlock({ + format: 'jpeg', + source: { bytes: new Uint8Array([1, 2, 3]) }, + }) + const restored = ImageBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('round-trips with url source', () => { + const original = new ImageBlock({ + format: 'png', + source: { url: 'https://example.com/image.png' }, + }) + const restored = ImageBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('round-trips with s3Location source', () => { + const original = new ImageBlock({ + format: 'webp', + source: { location: { type: 's3', uri: 's3://bucket/image.webp', bucketOwner: '123456789012' } }, + }) + const restored = ImageBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('encodes bytes as base64 in JSON output', () => { + const block = new ImageBlock({ + format: 'jpeg', + source: { bytes: new Uint8Array([1, 2, 3]) }, + }) + const json = block.toJSON() + expect(typeof (json.image.source as { bytes: unknown }).bytes).toBe('string') + }) +}) + +describe('VideoBlock toJSON/fromJSON', () => { + it('round-trips with bytes source', () => { + const original = new VideoBlock({ + format: 'mp4', + source: { bytes: new Uint8Array([10, 20, 30]) }, + }) + const restored = VideoBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('round-trips with s3Location source', () => { + const original = new VideoBlock({ + format: 'webm', + source: { location: { type: 's3', uri: 's3://bucket/video.webm' } }, + }) + const restored = VideoBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('encodes bytes as base64 in JSON output', () => { + const block = new VideoBlock({ + format: 'mp4', + source: { bytes: new Uint8Array([1, 2, 3]) }, + }) + const json = block.toJSON() + expect(typeof (json.video.source as { bytes: unknown }).bytes).toBe('string') + }) +}) + +describe('DocumentBlock toJSON/fromJSON', () => { + it('round-trips with bytes source', () => { + const original = new DocumentBlock({ + name: 'doc.pdf', + format: 'pdf', + source: { bytes: new Uint8Array([100, 200]) }, + }) + const restored = DocumentBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('round-trips with text source', () => { + const original = new DocumentBlock({ + name: 'note.txt', + format: 'txt', + source: { text: 'Hello world' }, + }) + const restored = DocumentBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('round-trips with content source', () => { + const original = new DocumentBlock({ + name: 'report.html', + format: 'html', + source: { content: [{ text: 'Introduction' }, { text: 'Conclusion' }] }, + }) + const restored = DocumentBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('round-trips with s3Location source', () => { + const original = new DocumentBlock({ + name: 'report.pdf', + format: 'pdf', + source: { location: { type: 's3', uri: 's3://bucket/report.pdf', bucketOwner: '123456789012' } }, + }) + const restored = DocumentBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('round-trips with citations and context', () => { + const original = new DocumentBlock({ + name: 'research.pdf', + format: 'pdf', + source: { bytes: new Uint8Array([1, 2, 3]) }, + citations: { enabled: true }, + context: 'Research paper about AI', + }) + const restored = DocumentBlock.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('omits undefined citations and context from JSON', () => { + const block = new DocumentBlock({ + name: 'doc.pdf', + format: 'pdf', + source: { bytes: new Uint8Array([1]) }, + }) + const json = block.toJSON() + expect('citations' in json.document).toBe(false) + expect('context' in json.document).toBe(false) + }) +}) diff --git a/strands-ts/src/types/__tests__/messages.test.ts b/strands-ts/src/types/__tests__/messages.test.ts new file mode 100644 index 0000000000..f5ee35f1a1 --- /dev/null +++ b/strands-ts/src/types/__tests__/messages.test.ts @@ -0,0 +1,765 @@ +import { describe, expect, test, it } from 'vitest' +import { + Message, + TextBlock, + ToolUseBlock, + ToolResultBlock, + ReasoningBlock, + CachePointBlock, + GuardContentBlock, + JsonBlock, + type MessageData, + type SystemPromptData, + systemPromptFromData, + systemPromptToData, +} from '../messages.js' +import { ImageBlock, VideoBlock, DocumentBlock, encodeBase64 } from '../media.js' +import { CitationsBlock } from '../citations.js' + +describe('Message', () => { + test('creates message with role and content', () => { + const content = [new TextBlock('test')] + const message = new Message({ role: 'user', content }) + + expect(message).toEqual({ + type: 'message', + role: 'user', + content, + }) + }) +}) + +describe('Message metadata', () => { + test('creates message without metadata', () => { + const message = new Message({ role: 'user', content: [new TextBlock('test')] }) + expect(message.metadata).toBeUndefined() + }) + + test('creates message with metadata', () => { + const metadata = { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + metrics: { latencyMs: 100 }, + } + const message = new Message({ role: 'assistant', content: [new TextBlock('hello')], metadata }) + expect(message.metadata).toStrictEqual(metadata) + }) + + test('creates message with custom metadata', () => { + const metadata = { + custom: { source: 'summarization', originalTurns: [5, 6, 7] }, + } + const message = new Message({ role: 'assistant', content: [new TextBlock('summary')], metadata }) + expect(message.metadata).toStrictEqual(metadata) + }) + + test('toJSON includes metadata when present', () => { + const metadata = { + usage: { inputTokens: 42, outputTokens: 10, totalTokens: 52 }, + metrics: { latencyMs: 200 }, + } + const message = new Message({ role: 'assistant', content: [new TextBlock('test')], metadata }) + const json = message.toJSON() + expect(json.metadata).toStrictEqual(metadata) + }) + + test('toJSON omits metadata when not present', () => { + const message = new Message({ role: 'user', content: [new TextBlock('test')] }) + const json = message.toJSON() + expect('metadata' in json).toBe(false) + }) + + test('fromMessageData preserves metadata', () => { + const data: MessageData = { + role: 'assistant', + content: [{ text: 'hello' }], + metadata: { + usage: { inputTokens: 10, outputTokens: 5, totalTokens: 15 }, + metrics: { latencyMs: 100 }, + }, + } + const message = Message.fromMessageData(data) + expect(message.metadata).toStrictEqual(data.metadata) + }) + + test('fromMessageData works without metadata', () => { + const data: MessageData = { + role: 'user', + content: [{ text: 'hello' }], + } + const message = Message.fromMessageData(data) + expect(message.metadata).toBeUndefined() + }) + + test('round-trips metadata through toJSON/fromJSON', () => { + const metadata = { + usage: { inputTokens: 42, outputTokens: 10, totalTokens: 52 }, + metrics: { latencyMs: 200 }, + custom: { source: 'test' }, + } + const original = new Message({ role: 'assistant', content: [new TextBlock('test')], metadata }) + const restored = Message.fromJSON(original.toJSON()) + expect(restored.metadata).toStrictEqual(metadata) + }) + + test('round-trips metadata through JSON.stringify/parse', () => { + const metadata = { + usage: { inputTokens: 1, outputTokens: 1, totalTokens: 2 }, + } + const original = new Message({ role: 'assistant', content: [new TextBlock('test')], metadata }) + const jsonString = JSON.stringify(original) + const restored = Message.fromJSON(JSON.parse(jsonString)) + expect(restored.metadata).toStrictEqual(metadata) + }) +}) + +describe('TextBlock', () => { + test('creates text block with text', () => { + const block = new TextBlock('hello') + + expect(block).toEqual({ + type: 'textBlock', + text: 'hello', + }) + }) +}) + +describe('ToolUseBlock', () => { + test('creates tool use block', () => { + const block = new ToolUseBlock({ + name: 'test-tool', + toolUseId: '123', + input: { param: 'value' }, + }) + + expect(block).toEqual({ + type: 'toolUseBlock', + name: 'test-tool', + toolUseId: '123', + input: { param: 'value' }, + }) + }) +}) + +describe('ToolResultBlock', () => { + test('creates tool result block', () => { + const block = new ToolResultBlock({ + toolUseId: '123', + status: 'success', + content: [new TextBlock('result')], + }) + + expect(block).toEqual({ + type: 'toolResultBlock', + toolUseId: '123', + status: 'success', + content: [new TextBlock('result')], + }) + }) +}) + +describe('ReasoningBlock', () => { + test('creates reasoning block with text', () => { + const block = new ReasoningBlock({ text: 'thinking...' }) + + expect(block).toEqual({ + type: 'reasoningBlock', + text: 'thinking...', + }) + }) +}) + +describe('CachePointBlock', () => { + test('creates cache point block', () => { + const block = new CachePointBlock({ cacheType: 'default' }) + + expect(block).toEqual({ + type: 'cachePointBlock', + cacheType: 'default', + }) + }) + + test('creates cache point block with ttl', () => { + const block = new CachePointBlock({ cacheType: 'default', ttl: '1h' }) + + expect(block).toEqual({ + type: 'cachePointBlock', + cacheType: 'default', + ttl: '1h', + }) + }) + + test('serializes ttl in toJSON', () => { + const block = new CachePointBlock({ cacheType: 'default', ttl: '5m' }) + + expect(block.toJSON()).toEqual({ + cachePoint: { cacheType: 'default', ttl: '5m' }, + }) + }) + + test('omits ttl in toJSON when not set', () => { + const block = new CachePointBlock({ cacheType: 'default' }) + + expect(block.toJSON()).toEqual({ + cachePoint: { cacheType: 'default' }, + }) + }) + + test('roundtrips ttl via fromJSON', () => { + const block = CachePointBlock.fromJSON({ cachePoint: { cacheType: 'default', ttl: '1h' } }) + + expect(block).toEqual({ + type: 'cachePointBlock', + cacheType: 'default', + ttl: '1h', + }) + }) +}) + +describe('JsonBlock', () => { + test('creates json block', () => { + const block = new JsonBlock({ json: { key: 'value' } }) + + expect(block).toEqual({ + type: 'jsonBlock', + json: { key: 'value' }, + }) + }) +}) + +describe('Message.fromMessageData', () => { + it('converts text block data to TextBlock', () => { + const messageData: MessageData = { + role: 'user', + content: [{ text: 'hello world' }], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toEqual(new TextBlock('hello world')) + }) + + it('converts tool use block data to ToolUseBlock', () => { + const messageData: MessageData = { + role: 'assistant', + content: [ + { + toolUse: { + toolUseId: 'tool-123', + name: 'test-tool', + input: { key: 'value' }, + }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toBeInstanceOf(ToolUseBlock) + expect(message.content[0]!.type).toBe('toolUseBlock') + }) + + it('converts tool result block data to ToolResultBlock with text content', () => { + const messageData: MessageData = { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-123', + status: 'success', + content: [{ text: 'result text' }], + }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toBeInstanceOf(ToolResultBlock) + const toolResultBlock = message.content[0] as ToolResultBlock + expect(toolResultBlock.content).toHaveLength(1) + expect(toolResultBlock.content[0]).toBeInstanceOf(TextBlock) + }) + + it('converts tool result block data to ToolResultBlock with json content', () => { + const messageData: MessageData = { + role: 'user', + content: [ + { + toolResult: { + toolUseId: 'tool-123', + status: 'success', + content: [{ json: { result: 'value' } }], + }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + const toolResultBlock = message.content[0] as ToolResultBlock + expect(toolResultBlock.content).toHaveLength(1) + expect(toolResultBlock.content[0]).toBeInstanceOf(JsonBlock) + }) + + it('converts reasoning block data to ReasoningBlock', () => { + const messageData: MessageData = { + role: 'assistant', + content: [ + { + reasoning: { text: 'thinking about it...' }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toBeInstanceOf(ReasoningBlock) + expect(message.content[0]!.type).toBe('reasoningBlock') + }) + + it('converts cache point block data to CachePointBlock', () => { + const messageData: MessageData = { + role: 'user', + content: [ + { + cachePoint: { cacheType: 'default' }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toBeInstanceOf(CachePointBlock) + expect(message.content[0]!.type).toBe('cachePointBlock') + }) + + it('converts guard content block data to GuardContentBlock', () => { + const messageData: MessageData = { + role: 'user', + content: [ + { + guardContent: { + text: { + text: 'guard this content', + qualifiers: ['guard_content'], + }, + }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]!.type).toBe('guardContentBlock') + }) + + it('converts image block data to ImageBlock', () => { + const messageData: MessageData = { + role: 'user', + content: [ + { + image: { + format: 'jpeg', + source: { bytes: new Uint8Array([1, 2, 3]) }, + }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toBeInstanceOf(ImageBlock) + expect(message.content[0]!.type).toBe('imageBlock') + }) + + it('converts video block data to VideoBlock', () => { + const messageData: MessageData = { + role: 'user', + content: [ + { + video: { + format: 'mp4', + source: { bytes: new Uint8Array([1, 2, 3]) }, + }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toBeInstanceOf(VideoBlock) + expect(message.content[0]!.type).toBe('videoBlock') + }) + + it('converts document block data to DocumentBlock', () => { + const messageData: MessageData = { + role: 'user', + content: [ + { + document: { + name: 'test.pdf', + format: 'pdf', + source: { bytes: new Uint8Array([1, 2, 3]) }, + }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toBeInstanceOf(DocumentBlock) + expect(message.content[0]!.type).toBe('documentBlock') + }) + + it('converts citations content block data to CitationsBlock', () => { + const messageData: MessageData = { + role: 'assistant', + content: [ + { + citations: { + citations: [ + { + location: { type: 'documentChar', documentIndex: 0, start: 10, end: 50 }, + source: 'doc-0', + sourceContent: [{ text: 'source text' }], + title: 'Test Doc', + }, + ], + content: [{ text: 'generated text' }], + }, + }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(1) + expect(message.content[0]).toBeInstanceOf(CitationsBlock) + expect(message.content[0]!.type).toBe('citationsBlock') + }) + + it('converts multiple content blocks', () => { + const messageData: MessageData = { + role: 'user', + content: [ + { text: 'first block' }, + { image: { format: 'png', source: { bytes: new Uint8Array([1, 2, 3]) } } }, + { text: 'second block' }, + ], + } + const message = Message.fromMessageData(messageData) + expect(message.content).toHaveLength(3) + expect(message.content[0]).toBeInstanceOf(TextBlock) + expect(message.content[1]).toBeInstanceOf(ImageBlock) + expect(message.content[2]).toBeInstanceOf(TextBlock) + }) + + it('throws error for unknown content block type', () => { + const messageData = { + role: 'user', + content: [{ unknownType: { data: 'value' } }], + } as unknown as MessageData + expect(() => Message.fromMessageData(messageData)).toThrow('Unknown ContentBlockData type') + }) +}) + +describe('systemPromptFromData', () => { + describe('when called with string', () => { + it('returns the string unchanged', () => { + const data: SystemPromptData = 'You are a helpful assistant' + const result = systemPromptFromData(data) + expect(result).toBe('You are a helpful assistant') + }) + }) + + describe('when called with TextBlockData', () => { + it('converts to TextBlock', () => { + const data: SystemPromptData = [{ text: 'System prompt text' }] + const result = systemPromptFromData(data) + expect(result).toEqual([new TextBlock('System prompt text')]) + }) + }) + + describe('when called with CachePointBlockData', () => { + it('converts to CachePointBlock', () => { + const data: SystemPromptData = [{ text: 'prompt' }, { cachePoint: { cacheType: 'default' } }] + const result = systemPromptFromData(data) + expect(result).toEqual([new TextBlock('prompt'), new CachePointBlock({ cacheType: 'default' })]) + }) + }) + + describe('when called with GuardContentBlockData', () => { + it('converts to GuardContentBlock', () => { + const data: SystemPromptData = [ + { + guardContent: { + text: { + text: 'guard this content', + qualifiers: ['guard_content'], + }, + }, + }, + ] + const result = systemPromptFromData(data) + expect(result).toEqual([ + new GuardContentBlock({ + text: { + text: 'guard this content', + qualifiers: ['guard_content'], + }, + }), + ]) + }) + }) + + describe('when called with mixed content blocks', () => { + it('converts all block types correctly', () => { + const data: SystemPromptData = [ + { text: 'First text block' }, + { cachePoint: { cacheType: 'default' } }, + { text: 'Second text block' }, + { + guardContent: { + text: { + text: 'guard content', + qualifiers: ['guard_content'], + }, + }, + }, + ] + const result = systemPromptFromData(data) + expect(result).toEqual([ + new TextBlock('First text block'), + new CachePointBlock({ cacheType: 'default' }), + new TextBlock('Second text block'), + new GuardContentBlock({ + text: { + text: 'guard content', + qualifiers: ['guard_content'], + }, + }), + ]) + }) + }) + + describe('when called with empty array', () => { + it('returns empty array', () => { + const data: SystemPromptData = [] + const result = systemPromptFromData(data) + expect(result).toEqual([]) + }) + }) + + describe('when called with unknown block type', () => { + it('throws error', () => { + const data = [{ unknownType: { data: 'value' } }] as unknown as SystemPromptData + expect(() => systemPromptFromData(data)).toThrow('Unknown SystemContentBlockData type') + }) + }) + + describe('when called with class instances', () => { + it('returns them unchanged', () => { + const systemPrompt = [new TextBlock('prompt'), new CachePointBlock({ cacheType: 'default' })] + const result = systemPromptFromData(systemPrompt) + expect(result).toEqual(systemPrompt) + }) + }) +}) + +describe('systemPromptToData', () => { + describe('when called with string', () => { + it('returns the string unchanged', () => { + const prompt = 'You are a helpful assistant' + const result = systemPromptToData(prompt) + expect(result).toBe('You are a helpful assistant') + }) + }) + + describe('when called with TextBlock array', () => { + it('converts to TextBlockData array', () => { + const prompt = [new TextBlock('System prompt text')] + const result = systemPromptToData(prompt) + expect(result).toEqual([{ text: 'System prompt text' }]) + }) + }) + + describe('when called with CachePointBlock array', () => { + it('converts to CachePointBlockData array', () => { + const prompt = [new TextBlock('prompt'), new CachePointBlock({ cacheType: 'default' })] + const result = systemPromptToData(prompt) + expect(result).toEqual([{ text: 'prompt' }, { cachePoint: { cacheType: 'default' } }]) + }) + }) + + describe('when called with GuardContentBlock array', () => { + it('converts to GuardContentBlockData array', () => { + const prompt = [ + new GuardContentBlock({ + text: { + text: 'guard this content', + qualifiers: ['guard_content'], + }, + }), + ] + const result = systemPromptToData(prompt) + expect(result).toEqual([ + { + guardContent: { + text: { + text: 'guard this content', + qualifiers: ['guard_content'], + }, + }, + }, + ]) + }) + }) + + describe('when called with mixed content blocks', () => { + it('converts all block types correctly', () => { + const prompt = [ + new TextBlock('First text block'), + new CachePointBlock({ cacheType: 'default' }), + new TextBlock('Second text block'), + new GuardContentBlock({ + text: { + text: 'guard content', + qualifiers: ['guard_content'], + }, + }), + ] + const result = systemPromptToData(prompt) + expect(result).toEqual([ + { text: 'First text block' }, + { cachePoint: { cacheType: 'default' } }, + { text: 'Second text block' }, + { + guardContent: { + text: { + text: 'guard content', + qualifiers: ['guard_content'], + }, + }, + }, + ]) + }) + }) + + describe('when called with empty array', () => { + it('returns empty array', () => { + const prompt: (TextBlock | CachePointBlock | GuardContentBlock)[] = [] + const result = systemPromptToData(prompt) + expect(result).toEqual([]) + }) + }) + + describe('round-trip conversion', () => { + it('preserves data through toData/fromData cycle', () => { + const original = [new TextBlock('prompt text'), new CachePointBlock({ cacheType: 'default' })] + const data = systemPromptToData(original) + const restored = systemPromptFromData(data) + expect(restored).toEqual(original) + }) + + it('preserves string through toData/fromData cycle', () => { + const original = 'Simple string prompt' + const data = systemPromptToData(original) + const restored = systemPromptFromData(data) + expect(restored).toBe(original) + }) + }) +}) + +describe('toJSON/fromJSON round-trips', () => { + // prettier-ignore + const roundTripCases = [ + ['TextBlock', () => new TextBlock('Hello world')], + ['ToolUseBlock without reasoningSignature',() => new ToolUseBlock({ name: 'test-tool', toolUseId: '123', input: { param: 'value' } })], + ['ToolUseBlock with reasoningSignature', () => new ToolUseBlock({ name: 'test-tool', toolUseId: '123', input: { param: 'value' }, reasoningSignature: 'sig123' })], + ['ToolResultBlock with text content', () => new ToolResultBlock({ toolUseId: '123', status: 'success', content: [new TextBlock('Result text')] })], + ['ToolResultBlock with json content', () => new ToolResultBlock({ toolUseId: '456', status: 'success', content: [new JsonBlock({ json: { result: 'data' } })] })], + ['ToolResultBlock with error status', () => new ToolResultBlock({ toolUseId: '789', status: 'error', content: [new TextBlock('Error message')] })], + ['ReasoningBlock with text only', () => new ReasoningBlock({ text: 'Thinking...' })], + ['ReasoningBlock with signature', () => new ReasoningBlock({ text: 'Thinking...', signature: 'sig123' })], + ['ReasoningBlock with redactedContent', () => new ReasoningBlock({ redactedContent: new Uint8Array([1, 2, 3]) })], + ['CachePointBlock', () => new CachePointBlock({ cacheType: 'default' })], + ['JsonBlock', () => new JsonBlock({ json: { key: 'value', nested: { a: 1 } } })], + ['GuardContentBlock with text', () => new GuardContentBlock({ text: { text: 'Guard this', qualifiers: ['guard_content'] } })], + ['GuardContentBlock with image', () => new GuardContentBlock({ image: { format: 'png', source: { bytes: new Uint8Array([1, 2, 3]) } } })], + ['Message with text content', () => new Message({ role: 'user', content: [new TextBlock('Hello')] })], + ['Message with multiple content blocks', () => new Message({ role: 'assistant', content: [new TextBlock('Here is the result'), new ToolUseBlock({ name: 'test-tool', toolUseId: '123', input: { key: 'value' } })] })], + ['Message with image content', () => new Message({ role: 'user', content: [new TextBlock('Check this image'), new ImageBlock({ format: 'png', source: { bytes: new Uint8Array([1, 2, 3]) } })] })], + ['CitationsBlock', () => new CitationsBlock({ citations: [{ location: { type: 'documentChar', documentIndex: 0, start: 0, end: 10 }, source: 'doc-0', sourceContent: [{ text: 'source' }], title: 'Test' }], content: [{ text: 'generated' }] })], + ] as const + + it.each(roundTripCases)('%s', (_name, createBlock) => { + const original = createBlock() + // Use duck-typing here + const BlockClass = original.constructor as unknown as { fromJSON(json: unknown): unknown } + const restored = BlockClass.fromJSON(original.toJSON()) + expect(restored).toEqual(original) + }) + + it('Message works with JSON.stringify', () => { + const original = new Message({ role: 'user', content: [new TextBlock('Test')] }) + const jsonString = JSON.stringify(original) + const restored = Message.fromJSON(JSON.parse(jsonString)) + expect(restored).toEqual(original) + }) +}) + +describe('fromJSON with serialized (base64 string) input', () => { + it('ReasoningBlock.fromJSON accepts base64 string for redactedContent', () => { + const originalBytes = new Uint8Array([1, 2, 3, 4, 5]) + const base64String = encodeBase64(originalBytes) + const block = ReasoningBlock.fromJSON({ + reasoning: { redactedContent: base64String }, + }) + expect(block.redactedContent).toEqual(originalBytes) + }) + + it('GuardContentBlock.fromJSON accepts base64 string for image bytes', () => { + const originalBytes = new Uint8Array([10, 20, 30]) + const base64String = encodeBase64(originalBytes) + const block = GuardContentBlock.fromJSON({ + guardContent: { + image: { format: 'png', source: { bytes: base64String } }, + }, + }) + expect(block.image?.source.bytes).toEqual(originalBytes) + }) +}) + +describe('toJSON format', () => { + it('TextBlock returns unwrapped format', () => { + const block = new TextBlock('Test') + expect(block.toJSON()).toStrictEqual({ text: 'Test' }) + }) + + it('JsonBlock returns unwrapped format', () => { + const block = new JsonBlock({ json: { test: true } }) + expect(block.toJSON()).toStrictEqual({ json: { test: true } }) + }) + + it('ToolUseBlock omits undefined reasoningSignature', () => { + const block = new ToolUseBlock({ name: 'test-tool', toolUseId: '123', input: {} }) + expect('reasoningSignature' in block.toJSON().toolUse).toBe(false) + }) + + it('ToolResultBlock does not serialize error field', () => { + const block = new ToolResultBlock({ + toolUseId: '123', + status: 'error', + content: [new TextBlock('Error')], + error: new Error('Test error'), + }) + expect('error' in block.toJSON().toolResult).toBe(false) + }) + + it('ReasoningBlock encodes redactedContent as base64', () => { + const block = new ReasoningBlock({ redactedContent: new Uint8Array([1, 2, 3]) }) + expect(typeof block.toJSON().reasoning.redactedContent).toBe('string') + }) + + it('ReasoningBlock omits undefined fields', () => { + const block = new ReasoningBlock({ text: 'Test' }) + const json = block.toJSON() + expect('signature' in json.reasoning).toBe(false) + expect('redactedContent' in json.reasoning).toBe(false) + }) + + it('GuardContentBlock encodes image bytes as base64', () => { + const block = new GuardContentBlock({ + image: { format: 'jpeg', source: { bytes: new Uint8Array([1, 2, 3]) } }, + }) + expect(typeof block.toJSON().guardContent.image?.source.bytes).toBe('string') + }) +}) diff --git a/strands-ts/src/types/__tests__/validation.test.ts b/strands-ts/src/types/__tests__/validation.test.ts new file mode 100644 index 0000000000..e6e7f700fb --- /dev/null +++ b/strands-ts/src/types/__tests__/validation.test.ts @@ -0,0 +1,34 @@ +import { describe, it, expect } from 'vitest' +import { ensureDefined } from '../validation.js' + +describe('ensureDefined', () => { + describe('when value is defined', () => { + it('returns the value', () => { + const value = 'test' + const result = ensureDefined(value, 'testField') + expect(result).toBe('test') + }) + + it('returns zero', () => { + const result = ensureDefined(0, 'numberField') + expect(result).toBe(0) + }) + + it('returns empty string', () => { + const result = ensureDefined('', 'stringField') + expect(result).toBe('') + }) + }) + + describe('when value is null', () => { + it('throws error with field name', () => { + expect(() => ensureDefined(null, 'testField')).toThrow('Expected testField to be defined, but got null') + }) + }) + + describe('when value is undefined', () => { + it('throws error with field name', () => { + expect(() => ensureDefined(undefined, 'testField')).toThrow('Expected testField to be defined, but got undefined') + }) + }) +}) diff --git a/strands-ts/src/types/agent.ts b/strands-ts/src/types/agent.ts new file mode 100644 index 0000000000..a842feab70 --- /dev/null +++ b/strands-ts/src/types/agent.ts @@ -0,0 +1,481 @@ +import type { StateStore } from '../state-store.js' +import type { ContentBlock, ContentBlockData, Message, MessageData, StopReason, SystemPrompt } from './messages.js' +import type { Interrupt } from '../interrupt.js' +import type { InterruptResponseContent, InterruptResponseContentData } from './interrupt.js' +import type { AgentTrace } from '../telemetry/tracer.js' +import type { Snapshot } from './snapshot.js' +import type { TakeSnapshotOptions } from '../agent/snapshot.js' +import type { + BeforeInvocationEvent, + AfterInvocationEvent, + BeforeModelCallEvent, + AfterModelCallEvent, + BeforeToolsEvent, + AfterToolsEvent, + BeforeToolCallEvent, + AfterToolCallEvent, + MessageAddedEvent, + ModelStreamUpdateEvent, + ContentBlockEvent, + ModelMessageEvent, + ToolResultEvent, + ToolStreamUpdateEvent, + AgentResultEvent, + InterruptEvent, + HookableEvent, + StreamEvent, +} from '../hooks/events.js' +import type { HookCallback, HookableEventConstructor, HookCallbackOptions, HookCleanup } from '../hooks/types.js' +import type { ToolRegistry } from '../registry/tool-registry.js' +import type { Model } from '../models/model.js' +import type { z } from 'zod' +import { AgentMetrics } from '../telemetry/meter.js' + +/** + * Arguments for invoking an agent. + * + * Supports multiple input formats: + * - `string` - User text input (wrapped in TextBlock, creates user Message) + * - `ContentBlock[]` | `ContentBlockData[]` - Array of content blocks (creates single user Message) + * - `Message[]` | `MessageData[]` - Array of messages (appends all to conversation) + * - `InterruptResponseContent[]` - Array of interrupt responses (resumes from interrupted state) + */ +export type InvokeArgs = + | string + | ContentBlock[] + | ContentBlockData[] + | Message[] + | MessageData[] + | InterruptResponseContent[] + | InterruptResponseContentData[] + +/** + * Per-invocation state threaded through hooks and tools for a single agent + * invocation, and returned on {@link AgentResult.invocationState}. One object + * per invocation, shared by reference; mutations by hooks or tools are visible + * to subsequent hooks, tools, and recursive loop cycles. + * + * Typically used for request-scoped context (`userId`, `requestId`, `traceId`) + * or cross-hook counters. The core agent loop writes no keys into it — the + * key space is the caller's. Transport bridges may populate reserved keys + * (e.g. `A2AExecutor` sets `a2aRequestContext`); those bridges document their + * own reserved keys. + * + * Distinct from {@link LocalAgent.appState}: `appState` is durable across + * invocations, JSON-serializable, and deep-copied. `invocationState` is + * ephemeral and accepts arbitrary values. + * + * Excluded from `toJSON()` on {@link AgentResult} and all hook events because + * values may not be serializable; callers produce a serialized form explicitly + * if needed. + */ +export type InvocationState = Record + +/** + * Options for a single agent invocation. + */ +export interface InvokeOptions { + /** + * Zod schema for structured output validation, overriding the constructor-provided schema for this invocation only. + */ + structuredOutputSchema?: z.ZodSchema + + /** + * Per-invocation state. Passed to lifecycle hook events and tools, and + * returned on {@link AgentResult.invocationState}. Mutable — hooks and tools + * may read and write. See {@link InvocationState} for details. + * + * Defaults to an empty object when omitted. + */ + invocationState?: InvocationState + + /** + * External AbortSignal for cancelling the agent invocation. + * + * Use this when cancellation is driven by something outside the agent — for example, + * a client disconnect, a framework-managed request lifecycle, or a declarative timeout. + * The agent composes this signal with its own internal controller, so both + * `agent.cancel()` and this signal can trigger cancellation independently. + * + * When the signal fires, the agent stops at the next cancellation checkpoint and + * returns an AgentResult with `stopReason: 'cancelled'`. See + * {@link LocalAgent.cancelSignal} for how tools can participate in cancellation. + * + * @example + * ```typescript + * // Timeout-based cancellation + * const result = await agent.invoke('Hello', { + * cancelSignal: AbortSignal.timeout(5000), + * }) + * + * // Framework-driven cancellation (e.g., client disconnect) + * app.post('/chat', async (req, res) => { + * const result = await agent.invoke(req.body.message, { + * cancelSignal: req.signal, + * }) + * res.json(result) + * }) + * ``` + */ + cancelSignal?: AbortSignal +} + +/** + * Interface for agents that support request-response invocation. + * + * Both `Agent` (full orchestration agent) and `A2AAgent` (remote agent proxy) + * implement this interface, enabling polymorphic usage across the SDK. + */ +export interface InvokableAgent { + /** + * The unique identifier of the agent instance. + */ + readonly id: string + + /** + * The name of the agent. + */ + readonly name?: string + + /** + * Optional description of what the agent does. + */ + readonly description?: string + + /** + * Invokes the agent and returns the final result. + * + * @param args - Arguments for invoking the agent + * @param options - Optional invocation options (e.g. structured output schema) + * @returns Promise that resolves to the final AgentResult + */ + invoke(args: InvokeArgs, options?: InvokeOptions): Promise + + /** + * Streams the agent execution, yielding events and returning the final result. + * + * @param args - Arguments for invoking the agent + * @param options - Optional invocation options (e.g. structured output schema) + * @returns Async generator that yields stream events and returns AgentResult + */ + stream(args: InvokeArgs, options?: InvokeOptions): AsyncGenerator +} + +/** + * Branded symbol that prevents external implementations of {@link LocalAgent}. + * + * @internal + */ +export declare const localAgentSymbol: unique symbol + +/** + * Interface for agents with locally accessible state, messages, tools, and hooks. + * + * This interface is exported for typing purposes only (e.g. in {@link ToolContext}, + * hook events, and {@link Plugin.initAgent}). The Strands SDK is responsible for + * providing all implementations. External code should not implement this interface. + * + * @internal Not for external implementation. Use the {@link Agent} class instead. + */ +export interface LocalAgent { + /** @internal Prevents external implementations of this interface. */ + readonly [localAgentSymbol]: true + + /** + * The unique identifier of the agent instance. + */ + readonly id: string + + /** + * App state storage accessible to tools and application logic. + */ + appState: StateStore + + /** + * The conversation history of messages between user and assistant. + */ + messages: Message[] + + /** + * Runtime state for the model provider. Used by stateful models to persist + * provider-specific data (e.g., response IDs for server-side conversation chaining) + * across invocations. + */ + modelState: StateStore + + /** + * The tool registry for registering tools with the agent. + */ + readonly toolRegistry: ToolRegistry + + /** + * The model provider used by the agent for inference. + */ + readonly model: Model + + /** + * The system prompt to pass to the model provider. + */ + systemPrompt?: SystemPrompt + + /** + * The cancellation signal for the current invocation. + * + * Cancellation in the SDK is **cooperative**. The agent checks for cancellation at + * built-in checkpoints (between loop cycles, during model streaming, and between + * sequential tool executions), but once a tool callback is running, only the tool + * itself can respond to cancellation. There are two patterns: + * + * **Polling** — check `cancelSignal.aborted` between steps in a loop: + * ```ts + * callback: async ({ items }, context) => { + * const results = [] + * for (const item of items) { + * if (context.agent.cancelSignal.aborted) return results + * results.push(await process(item)) + * } + * return results + * } + * ``` + * + * **Signal forwarding** — pass to APIs that accept `AbortSignal`: + * ```ts + * callback: async ({ url }, context) => { + * const res = await fetch(url, { signal: context.agent.cancelSignal }) + * return res.text() + * } + * ``` + * + * If a tool does neither, it will run to completion even after cancellation is + * requested. The agent will resume cancellation handling after the tool returns. + * + * The cancelSignal can also be utilized in hook callbacks. + */ + readonly cancelSignal: AbortSignal + + /** + * Register a hook callback for a specific event type. + * + * Hooks execute in order from lowest to highest. Lower values always run + * first, on both Before* and After* events. Within the same order, After* + * events reverse registration order for cleanup symmetry. + * + * @param eventType - The event class constructor to register the callback for + * @param callback - The callback function to invoke when the event occurs + * @param options - Optional configuration including execution order + * @returns Cleanup function that removes the callback when invoked + */ + addHook( + eventType: HookableEventConstructor, + callback: HookCallback, + options?: HookCallbackOptions + ): HookCleanup + + /** + * Captures a point-in-time snapshot of the agent's current state. + * + * @param options - Controls which fields to capture and optional app data to store + * @returns A Snapshot containing the captured agent state + */ + takeSnapshot(options: TakeSnapshotOptions): Snapshot + + /** + * Restores agent state from a previously captured snapshot. + * + * Only fields present in `snapshot.data` are restored; absent fields are left unchanged. + * + * @param snapshot - The snapshot to restore from + */ + loadSnapshot(snapshot: Snapshot): void +} + +/** + * Result returned by the agent loop. + */ +export class AgentResult { + readonly type = 'agentResult' as const + + /** + * The stop reason from the final model response. + */ + readonly stopReason: StopReason + + /** + * The last message added to the messages array. + */ + readonly lastMessage: Message + + /** + * Local execution traces collected during the agent invocation. + * Contains timing and hierarchy of operations within the agent loop. + */ + readonly traces?: AgentTrace[] + + /** + * The validated structured output from the LLM, if a schema was provided. + * Type represents any validated Zod schema output. + */ + readonly structuredOutput?: z.output + + /** + * Aggregated metrics for the agent's loop execution. + * Tracks cycle counts, token usage, tool execution stats, and model latency. + */ + readonly metrics?: AgentMetrics + + /** + * Per-invocation state passed into the agent, threaded through hooks and + * tools, and surfaced here at the end of the invocation. See + * {@link InvocationState} for details. Always defined — defaults to `{}` when + * no `invocationState` was provided in {@link InvokeOptions}. + */ + readonly invocationState: InvocationState + + /** + * Interrupts that caused the agent to stop, when `stopReason` is `'interrupt'`. + * Contains the unanswered interrupts that require human input to resume. + */ + readonly interrupts?: Interrupt[] + + constructor(data: { + stopReason: StopReason + lastMessage: Message + invocationState: InvocationState + traces?: AgentTrace[] + metrics?: AgentMetrics + structuredOutput?: z.output + interrupts?: Interrupt[] + }) { + this.stopReason = data.stopReason + this.lastMessage = data.lastMessage + this.invocationState = data.invocationState + if (data.traces !== undefined) { + this.traces = data.traces + } + if (data.metrics !== undefined) { + this.metrics = data.metrics + } + if (data.structuredOutput !== undefined) { + this.structuredOutput = data.structuredOutput + } + if (data.interrupts !== undefined) { + this.interrupts = data.interrupts + } + } + + /** + * The most recent input token count from the last model invocation. + * Convenience accessor that delegates to `metrics.latestContextSize`. + * Returns `undefined` when no metrics or invocations are available. + */ + get contextSize(): number | undefined { + return this.metrics?.latestContextSize + } + + /** + * Projected context size for the next model call (inputTokens + outputTokens from the last call). + * Convenience accessor that delegates to `metrics.projectedContextSize`. + * Returns `undefined` when no metrics or invocations are available. + */ + get projectedContextSize(): number | undefined { + return this.metrics?.projectedContextSize + } + + /** + * Custom JSON serialization that excludes traces, metrics, and invocationState. + * Traces and metrics are excluded to avoid sending large payloads over the wire + * in API responses; `invocationState` is excluded because its values are + * caller-owned and may not be serializable (see {@link InvocationState}). + * + * All three remain accessible via their properties for debugging. + * + * @returns Object representation for safe serialization + */ + public toJSON(): object { + return { + type: this.type, + stopReason: this.stopReason, + lastMessage: this.lastMessage, + ...(this.structuredOutput !== undefined && { structuredOutput: this.structuredOutput }), + } + } + + /** + * Extracts a string representation of the result. + * + * Priority order: + * 1. `interrupts` serialized as JSON, if any are present + * 2. `structuredOutput` serialized as JSON + * 3. Text from `textBlock`, `reasoningBlock`, and `citationsBlock` content blocks + * + * @returns String representation of the result: JSON for interrupts/structuredOutput, or text content joined by newlines. + */ + public toString(): string { + if (this.interrupts && this.interrupts.length > 0) { + return JSON.stringify(this.interrupts) + } + + if (this.structuredOutput !== undefined) { + return JSON.stringify(this.structuredOutput) + } + + const textParts: string[] = [] + + for (const block of this.lastMessage.content) { + switch (block.type) { + case 'textBlock': + textParts.push(block.text) + break + case 'reasoningBlock': + if (block.text) { + // Add indentation to reasoning content + const indentedText = block.text.replace(/\n/g, '\n ') + textParts.push(`💭 Reasoning:\n ${indentedText}`) + } + break + case 'citationsBlock': + for (const c of block.content) { + if ('text' in c) { + textParts.push(c.text) + } + } + break + default: + console.debug(`Skipping content block type: ${block.type}`) + break + } + } + + return textParts.join('\n') + } +} + +/** + * Union type representing all possible streaming events from an agent. + * This includes model events, tool events, and agent-specific lifecycle events. + * + * This is a discriminated union where each event has a unique type field, + * allowing for type-safe event handling using switch statements. + * + * Every member extends {@link HookableEvent} (which extends {@link StreamEvent}), + * making all events both streamable and subscribable via hook callbacks. + * Raw data objects from lower layers (model, tools) should be wrapped + * in a StreamEvent subclass at the agent boundary rather than added directly. + */ +export type AgentStreamEvent = + | ModelStreamUpdateEvent + | ContentBlockEvent + | ModelMessageEvent + | ToolStreamUpdateEvent + | ToolResultEvent + | BeforeInvocationEvent + | AfterInvocationEvent + | BeforeModelCallEvent + | AfterModelCallEvent + | BeforeToolsEvent + | AfterToolsEvent + | BeforeToolCallEvent + | AfterToolCallEvent + | MessageAddedEvent + | InterruptEvent + | AgentResultEvent diff --git a/strands-ts/src/types/citations.ts b/strands-ts/src/types/citations.ts new file mode 100644 index 0000000000..a4e1b8b7cc --- /dev/null +++ b/strands-ts/src/types/citations.ts @@ -0,0 +1,218 @@ +import type { JSONSerializable, Serialized } from './json.js' + +/** + * Citation types for document citation content blocks. + * + * Citations are returned by models when document citations are enabled. + * They are output-only blocks that appear in conversation history. + */ + +/** + * Discriminated union of citation location types. + * Each variant uses a `type` field to identify the location kind. + */ +export type CitationLocation = + | { + /** + * Location referencing character positions within a document. + */ + type: 'documentChar' + + /** + * Index of the source document. + */ + documentIndex: number + + /** + * Start character position. + */ + start: number + + /** + * End character position. + */ + end: number + } + | { + /** + * Location referencing page positions within a document. + */ + type: 'documentPage' + + /** + * Index of the source document. + */ + documentIndex: number + + /** + * Start page number. + */ + start: number + + /** + * End page number. + */ + end: number + } + | { + /** + * Location referencing chunk positions within a document. + */ + type: 'documentChunk' + + /** + * Index of the source document. + */ + documentIndex: number + + /** + * Start chunk index. + */ + start: number + + /** + * End chunk index. + */ + end: number + } + | { + /** + * Location referencing a search result. + */ + type: 'searchResult' + + /** + * Index of the search result. + */ + searchResultIndex: number + + /** + * Start position within the search result. + */ + start: number + + /** + * End position within the search result. + */ + end: number + } + | { + /** + * Location referencing a web URL. + */ + type: 'web' + + /** + * The URL of the web source. + */ + url: string + + /** + * The domain of the web source. + */ + domain?: string + } + +/** + * Source content referenced by a citation. + * Modeled as a union type for future extensibility. + */ +export type CitationSourceContent = { text: string } + +/** + * Generated content associated with a citation. + * Modeled as a union type for future extensibility. + */ +export type CitationGeneratedContent = { text: string } + +/** + * A single citation linking generated content to a source location. + */ +export interface Citation { + /** + * The location of the cited source. + */ + location: CitationLocation + + /** + * The source identifier string. + */ + source: string + + /** + * The source content referenced by this citation. + */ + sourceContent: CitationSourceContent[] + + /** + * Title of the cited source. + */ + title: string +} + +/** + * Data for a citations content block. + */ +export interface CitationsBlockData { + /** + * Array of citations linking generated content to source locations. + */ + citations: Citation[] + + /** + * The generated content associated with these citations. + */ + content: CitationGeneratedContent[] +} + +/** + * Citations content block within a message. + * Returned by models when document citations are enabled. + * This is an output-only block — users do not construct these directly. + */ +export class CitationsBlock + implements CitationsBlockData, JSONSerializable<{ citations: Serialized }> +{ + /** + * Discriminator for citations content. + */ + readonly type = 'citationsBlock' as const + + /** + * Array of citations linking generated content to source locations. + */ + readonly citations: Citation[] + + /** + * The generated content associated with these citations. + */ + readonly content: CitationGeneratedContent[] + + constructor(data: CitationsBlockData) { + this.citations = data.citations + this.content = data.content + } + + /** + * Serializes the CitationsBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + */ + toJSON(): { citations: Serialized } { + return { + citations: { + citations: this.citations, + content: this.content, + }, + } + } + + /** + * Creates a CitationsBlock instance from its wrapped data format. + * + * @param data - Wrapped CitationsBlockData to deserialize + * @returns CitationsBlock instance + */ + static fromJSON(data: { citations: Serialized }): CitationsBlock { + return new CitationsBlock(data.citations) + } +} diff --git a/strands-ts/src/types/elicitation.ts b/strands-ts/src/types/elicitation.ts new file mode 100644 index 0000000000..99522b0245 --- /dev/null +++ b/strands-ts/src/types/elicitation.ts @@ -0,0 +1,21 @@ +import type { + ElicitResult, + ElicitRequestParams, + ClientRequest, + ClientNotification, +} from '@modelcontextprotocol/sdk/types.js' +import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js' + +/** + * Context provided to an elicitation callback, including the abort signal for the in-flight request. + */ +export type ElicitationContext = RequestHandlerExtra + +/** + * Callback invoked when an MCP server sends an elicitation request to gather user input during tool execution. + * + * @param context - Request context including abort signal. + * @param params - The elicitation parameters from the server (message, requested schema or URL). + * @returns The user's response: accept (with content), decline, or cancel. + */ +export type ElicitationCallback = (context: ElicitationContext, params: ElicitRequestParams) => Promise diff --git a/strands-ts/src/types/interrupt.ts b/strands-ts/src/types/interrupt.ts new file mode 100644 index 0000000000..97066362b4 --- /dev/null +++ b/strands-ts/src/types/interrupt.ts @@ -0,0 +1,132 @@ +/** + * Interrupt-related type definitions for human-in-the-loop workflows. + * + * These types define the data structures used when invoking agents with + * interrupt responses to resume execution. + */ + +import type { JSONValue } from './json.js' +import type { JSONSerializable } from './json.js' + +/** + * Parameters for raising an interrupt. + */ +export interface InterruptParams { + /** + * User-defined name for the interrupt. + * Must be unique within a single hook callback or tool execution. + */ + name: string + + /** + * User-provided reason for the interrupt. + */ + reason?: JSONValue + + /** + * Preemptive response to use if available. + * When provided, the interrupt returns this value immediately without + * halting agent execution. Useful for session-managed trust responses + * where a previous user response can be reused. + * + * @example + * ```typescript + * // If user already approved in a previous session, skip the interrupt + * const approval = context.interrupt({ + * name: 'confirm_delete', + * reason: 'Confirm deletion?', + * response: agent.appState['savedApproval'], + * }) + * ``` + */ + response?: JSONValue +} + +/** + * User response to an interrupt. + */ +export interface InterruptResponse { + /** + * Unique identifier of the interrupt being responded to. + */ + interruptId: string + + /** + * User's response to the interrupt. + */ + response: JSONValue +} + +/** + * Data format for a content block containing a user response to an interrupt. + */ +export interface InterruptResponseContentData { + /** + * The interrupt response data. + */ + interruptResponse: InterruptResponse +} + +/** + * Content block containing a user response to an interrupt. + * Used when invoking an agent to resume from an interrupted state. + * + * @example + * ```typescript + * const content = new InterruptResponseContent({ + * interruptId: interrupt.id, + * response: 'approved', + * }) + * ``` + */ +export class InterruptResponseContent + implements InterruptResponseContentData, JSONSerializable +{ + /** + * Discriminator for interrupt response content blocks. + */ + readonly type = 'interruptResponseContent' as const + + /** + * The interrupt response data. + */ + readonly interruptResponse: InterruptResponse + + constructor(data: InterruptResponse) { + this.interruptResponse = data + } + + /** + * Serializes to a JSON-compatible {@link InterruptResponseContentData} object. + * Called automatically by `JSON.stringify()`. + */ + toJSON(): InterruptResponseContentData { + return { interruptResponse: this.interruptResponse } + } + + /** + * Creates an InterruptResponseContent instance from data. + * + * @param data - Data to deserialize + * @returns InterruptResponseContent instance + */ + static fromJSON(data: InterruptResponseContentData): InterruptResponseContent { + return new InterruptResponseContent(data.interruptResponse) + } +} + +/** + * Type guard that checks whether a value is an {@link InterruptResponseContent}. + * + * @internal + */ +export function isInterruptResponseContent(value: unknown): value is InterruptResponseContent { + if (value instanceof InterruptResponseContent) { + return true + } + if (typeof value !== 'object' || value === null || !('interruptResponse' in value)) { + return false + } + const { interruptResponse } = value as InterruptResponseContentData + return typeof interruptResponse === 'object' && interruptResponse !== null && 'interruptId' in interruptResponse +} diff --git a/strands-ts/src/types/json.ts b/strands-ts/src/types/json.ts new file mode 100644 index 0000000000..b0af105f49 --- /dev/null +++ b/strands-ts/src/types/json.ts @@ -0,0 +1,191 @@ +import type { JSONSchema7 } from 'json-schema' +/** + * Interface for objects that can be serialized to JSON via `toJSON()`. + * + * @typeParam T - The type returned by `toJSON()`. + */ +export interface JSONSerializable { + toJSON(): T +} + +import { JsonValidationError } from '../errors.js' + +/** + * Represents any valid JSON value. + * This type ensures type safety for JSON-serializable data. + * + * @example + * ```typescript + * const value: JSONValue = { key: 'value', nested: { arr: [1, 2, 3] } } + * const text: JSONValue = 'hello' + * const num: JSONValue = 42 + * const bool: JSONValue = true + * const nothing: JSONValue = null + * ``` + */ +export type JSONValue = string | number | boolean | null | { [key: string]: JSONValue } | JSONValue[] + +/** + * Represents a JSON Schema definition. + * Used for defining the structure of tool inputs and outputs. + * + * This is based on JSON Schema Draft 7 specification. + * + * @example + * ```typescript + * const schema: JSONSchema = { + * type: 'object', + * properties: { + * name: { type: 'string' }, + * age: { type: 'number' } + * }, + * required: ['name'] + * } + * ``` + */ +export type JSONSchema = JSONSchema7 + +/** + * Creates a deep copy of a value using JSON serialization. + * + * @param value - The value to copy + * @returns A deep copy of the value + * @throws Error if the value cannot be JSON serialized + */ +export function deepCopy(value: unknown): JSONValue { + try { + return JSON.parse(JSON.stringify(value)) as JSONValue + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + throw new Error(`Unable to serialize tool result: ${errorMessage}`) + } +} + +/** + * Creates a deep copy of a value with explicit validation for non-serializable types. + * Uses JSON.stringify's replacer to detect and report non-serializable values with path information. + * + * @param value - The value to copy + * @param contextPath - Context path for error messages (e.g., 'initialState', 'value for key "config"') + * @returns A deep copy of the value + * @throws JsonValidationError if value contains functions, symbols, or undefined values + */ +export function deepCopyWithValidation(value: unknown, contextPath: string = 'value'): JSONValue { + const pathStack: string[] = [] + + const replacer = (key: string, val: unknown): unknown => { + // Build current path + let currentPath = contextPath + if (key !== '') { + // Check if parent is array (numeric key pattern) + const isArrayIndex = /^\d+$/.test(key) + if (isArrayIndex) { + currentPath = pathStack.length > 0 ? `${pathStack[pathStack.length - 1]}[${key}]` : `${contextPath}[${key}]` + } else { + currentPath = pathStack.length > 0 ? `${pathStack[pathStack.length - 1]}.${key}` : `${contextPath}.${key}` + } + } + + // Check for non-serializable types + if (typeof val === 'function') { + throw new JsonValidationError(`${currentPath} contains a function which cannot be serialized`) + } + + if (typeof val === 'symbol') { + throw new JsonValidationError(`${currentPath} contains a symbol which cannot be serialized`) + } + + if (val === undefined) { + throw new JsonValidationError(`${currentPath} is undefined which cannot be serialized`) + } + + // Track path for nested objects/arrays + if (val !== null && typeof val === 'object') { + pathStack.push(currentPath) + } + + return val + } + + try { + const serialized = JSON.stringify(value, replacer) + return JSON.parse(serialized) as JSONValue + } catch (error) { + // If it's our validation error, re-throw it + if (error instanceof JsonValidationError) { + throw error + } + // Otherwise, wrap it + const errorMessage = error instanceof Error ? error.message : String(error) + throw new Error(`Unable to serialize value: ${errorMessage}`) + } +} + +/** + * Removes undefined values from an object. + * Useful for JSON serialization to avoid including undefined fields in output. + * + * @param obj - Object with potentially undefined values + * @returns New object with undefined values removed + * + * @example + * ```typescript + * const data = { name: 'test', value: undefined, count: 0 } + * const clean = omitUndefined(data) + * // Result: { name: 'test', count: 0 } + * ``` + */ +export function omitUndefined(obj: T): { [K in keyof T]: Exclude } { + const result = {} as { [K in keyof T]: Exclude } + for (const [key, value] of Object.entries(obj)) { + if (value !== undefined) { + ;(result as Record)[key] = value + } + } + return result +} +/** + * Recursively transforms a type by converting all Uint8Array properties to strings. + * Used for JSON serialization where binary data is encoded as base64 strings. + * + * @example + * ```typescript + * interface Data { + * name: string + * bytes: Uint8Array + * nested: { content: Uint8Array } + * } + * + * type SerializedData = Serialized + * // Result: { name: string; bytes: string; nested: { content: string } } + * ``` + */ +export type Serialized = T extends Uint8Array + ? string + : T extends (infer U)[] + ? Serialized[] + : T extends object + ? { [K in keyof T]: Serialized } + : T + +/** + * Represents data that may contain either Uint8Array (runtime) or string (serialized) for binary fields. + * Used for deserialization where input may come from JSON (strings) or direct construction (Uint8Array). + * + * @example + * ```typescript + * interface Data { + * bytes: Uint8Array + * } + * + * type InputData = MaybeSerializedInput + * // Result: { bytes: Uint8Array | string } + * ``` + */ +export type MaybeSerializedInput = T extends Uint8Array + ? Uint8Array | string + : T extends (infer U)[] + ? MaybeSerializedInput[] + : T extends object + ? { [K in keyof T]: MaybeSerializedInput } + : T diff --git a/strands-ts/src/types/lifecycle-observer.ts b/strands-ts/src/types/lifecycle-observer.ts new file mode 100644 index 0000000000..80439f610d --- /dev/null +++ b/strands-ts/src/types/lifecycle-observer.ts @@ -0,0 +1,19 @@ +import type { LocalAgent } from './agent.js' + +/** + * Implementors are given the agent at registration time so they can subscribe + * to hook events of their choice via {@link LocalAgent.addHook}. This is the + * extension point for components that need to observe arbitrary lifecycle + * events. Each observer method is optional — implementors define only the + * surfaces they care about, and the agent probes for each at registration. + */ +export interface LifecycleObserver { + /** Stable identifier for this observer. Used for logging and duplicate detection. */ + readonly name: string + + /** + * Called once when the observer is registered with an agent. Implementations + * typically subscribe to one or more events via `agent.addHook`. + */ + observeAgent?(agent: LocalAgent): void | Promise +} diff --git a/strands-ts/src/types/media.ts b/strands-ts/src/types/media.ts new file mode 100644 index 0000000000..987319ba7b --- /dev/null +++ b/strands-ts/src/types/media.ts @@ -0,0 +1,559 @@ +/** + * Media and document content types for multimodal AI interactions. + * + * This module provides types for handling images, videos, and documents + * with support for multiple sources (bytes, S3, URLs, files). + */ + +import type { Serialized, MaybeSerializedInput, JSONSerializable } from './json.js' +import { omitUndefined } from './json.js' +import { TextBlock, type TextBlockData } from './messages.js' + +export type { ImageFormat, VideoFormat, DocumentFormat, MediaFormat } from '../mime.js' +import type { ImageFormat, VideoFormat, DocumentFormat } from '../mime.js' + +/** + * Cross-platform base64 encoding function that works in both browser and Node.js environments. + */ +export function encodeBase64(input: string | Uint8Array): string { + // Handle Uint8Array (Image/PDF bytes) + if (input instanceof Uint8Array) { + // Node.js: Fast and zero copy + if (typeof globalThis.Buffer === 'function') { + return globalThis.Buffer.from(input).toString('base64') + } + + // Browser: Safe conversion which doesn't cause a stack overflow like when using the spread operator. + // We convert bytes to binary string in chunks to satisfy btoa() + const CHUNK_SIZE = 0x8000 // 32k chunks + let binary = '' + for (let i = 0; i < input.length; i += CHUNK_SIZE) { + binary += String.fromCharCode.apply( + null, + input.subarray(i, Math.min(i + CHUNK_SIZE, input.length)) as unknown as number[] + ) + } + + return globalThis.btoa(binary) + } + + if (typeof globalThis.btoa === 'function') { + return globalThis.btoa(input) + } + + return globalThis.Buffer.from(input, 'binary').toString('base64') +} + +/** + * Cross-platform base64 decoding function that works in both browser and Node.js environments. + * + * @param input - Base64 encoded string to decode + * @returns Decoded bytes as Uint8Array + */ +export function decodeBase64(input: string): Uint8Array { + // Node.js: Fast path using Buffer + if (typeof globalThis.Buffer === 'function') { + return new Uint8Array(globalThis.Buffer.from(input, 'base64')) + } + + // Browser: Use atob to decode base64 to binary string, then convert to bytes + const binary = globalThis.atob(input) + const bytes = new Uint8Array(binary.length) + for (let i = 0; i < binary.length; i++) { + bytes[i] = binary.charCodeAt(i) + } + return bytes +} + +/** + * Base interface for a document/media source location. + */ +export interface LocationData { + /** + * Location type discriminator. + */ + type: string +} + +/** + * Data for an S3 location. + */ +export interface S3LocationData extends LocationData { + /** + * Location type — always "s3". + */ + type: 's3' + + /** + * S3 URI in format: s3://bucket-name/key-name + */ + uri: string + + /** + * AWS account ID of the S3 bucket owner (12-digit). + * Required if the bucket belongs to another AWS account. + */ + bucketOwner?: string +} + +/** + * S3 location for media and document sources. + */ +export class S3Location implements S3LocationData, JSONSerializable { + readonly type = 's3' as const + readonly uri: string + readonly bucketOwner?: string + + constructor(data: Omit & { type?: 's3' }) { + this.uri = data.uri + if (data.bucketOwner !== undefined) { + this.bucketOwner = data.bucketOwner + } + } + + /** + * Serializes the S3Location to a JSON-compatible S3LocationData object. + * Called automatically by JSON.stringify(). + */ + toJSON(): S3LocationData { + return omitUndefined({ + type: this.type, + uri: this.uri, + bucketOwner: this.bucketOwner, + }) as S3LocationData + } + + /** + * Creates an S3Location instance from S3LocationData. + * + * @param data - S3LocationData to deserialize + * @returns S3Location instance + */ + static fromJSON(data: S3LocationData): S3Location { + return new S3Location(data) + } +} + +/** + * Source for an image (Data version). + * Supports multiple formats for different providers. + */ +export type ImageSourceData = + | { bytes: Uint8Array } // raw binary data + | { location: S3LocationData } // remote location reference + | { url: string } // https:// + +/** + * Source for an image (Class version). + */ +export type ImageSource = + | { type: 'imageSourceBytes'; bytes: Uint8Array } + | { type: 'imageSourceS3Location'; location: S3Location } + | { type: 'imageSourceUrl'; url: string } + +/** + * Data for an image block. + */ +export interface ImageBlockData { + /** + * Image format. + */ + format: ImageFormat + + /** + * Image source. + */ + source: ImageSourceData +} + +/** + * Image content block. + */ +export class ImageBlock implements ImageBlockData, JSONSerializable<{ image: Serialized }> { + /** + * Discriminator for image content. + */ + readonly type = 'imageBlock' as const + + /** + * Image format. + */ + readonly format: ImageFormat + + /** + * Image source. + */ + readonly source: ImageSource + + constructor(data: ImageBlockData) { + this.format = data.format + this.source = this._convertSource(data.source) + } + + private _convertSource(source: ImageSourceData): ImageSource { + if ('bytes' in source) { + return { + type: 'imageSourceBytes', + bytes: source.bytes, + } + } + if ('url' in source) { + return { + type: 'imageSourceUrl', + url: source.url, + } + } + if ('location' in source) { + return { + type: 'imageSourceS3Location', + location: new S3Location(source.location), + } + } + throw new Error('Invalid image source') + } + + /** + * Serializes the ImageBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + * Uint8Array bytes are encoded as base64 string. + */ + toJSON(): { image: Serialized } { + let source: Serialized + if (this.source.type === 'imageSourceBytes') { + source = { bytes: encodeBase64(this.source.bytes) } + } else if (this.source.type === 'imageSourceUrl') { + source = { url: this.source.url } + } else { + source = { location: this.source.location.toJSON() } + } + return { + image: { + format: this.format, + source, + }, + } + } + + /** + * Creates an ImageBlock instance from its wrapped data format. + * Base64-encoded bytes are decoded back to Uint8Array. + * + * @param data - Wrapped ImageBlockData to deserialize (accepts both string and Uint8Array for bytes) + * @returns ImageBlock instance + */ + static fromJSON(data: { image: MaybeSerializedInput }): ImageBlock { + const image = data.image + let source: ImageSourceData + if ('bytes' in image.source) { + const bytes = image.source.bytes + source = { bytes: typeof bytes === 'string' ? decodeBase64(bytes) : bytes } + } else if ('url' in image.source) { + source = { url: image.source.url } + } else { + source = { location: image.source.location } + } + return new ImageBlock({ + format: image.format, + source, + }) + } +} + +/** + * Source for a video (Data version). + */ +export type VideoSourceData = { bytes: Uint8Array } | { location: S3LocationData } // remote location reference + +/** + * Source for a video (Class version). + */ +export type VideoSource = + | { type: 'videoSourceBytes'; bytes: Uint8Array } + | { type: 'videoSourceS3Location'; location: S3Location } + +/** + * Data for a video block. + */ +export interface VideoBlockData { + /** + * Video format. + */ + format: VideoFormat + + /** + * Video source. + */ + source: VideoSourceData +} + +/** + * Video content block. + */ +export class VideoBlock implements VideoBlockData, JSONSerializable<{ video: Serialized }> { + /** + * Discriminator for video content. + */ + readonly type = 'videoBlock' as const + + /** + * Video format. + */ + readonly format: VideoFormat + + /** + * Video source. + */ + readonly source: VideoSource + + constructor(data: VideoBlockData) { + this.format = data.format + this.source = this._convertSource(data.source) + } + + private _convertSource(source: VideoSourceData): VideoSource { + if ('bytes' in source) { + return { + type: 'videoSourceBytes', + bytes: source.bytes, + } + } + if ('location' in source) { + return { type: 'videoSourceS3Location', location: new S3Location(source.location) } + } + throw new Error('Invalid video source') + } + + /** + * Serializes the VideoBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + * Uint8Array bytes are encoded as base64 string. + */ + toJSON(): { video: Serialized } { + let source: Serialized + if (this.source.type === 'videoSourceBytes') { + source = { bytes: encodeBase64(this.source.bytes) } + } else { + source = { location: this.source.location.toJSON() } + } + return { + video: { + format: this.format, + source, + }, + } + } + + /** + * Creates a VideoBlock instance from its wrapped data format. + * Base64-encoded bytes are decoded back to Uint8Array. + * + * @param data - Wrapped VideoBlockData to deserialize (accepts both string and Uint8Array for bytes) + * @returns VideoBlock instance + */ + static fromJSON(data: { video: MaybeSerializedInput }): VideoBlock { + const video = data.video + let source: VideoSourceData + if ('bytes' in video.source) { + const bytes = video.source.bytes + source = { bytes: typeof bytes === 'string' ? decodeBase64(bytes) : bytes } + } else { + source = { location: video.source.location } + } + return new VideoBlock({ + format: video.format, + source, + }) + } +} + +/** + * Content blocks that can be nested inside a document. + * Documents can contain text blocks for structured content. + */ +export type DocumentContentBlockData = TextBlockData +export type DocumentContentBlock = TextBlock + +/** + * Source for a document (Data version). + * Supports multiple formats including structured content. + */ +export type DocumentSourceData = + | { bytes: Uint8Array } // raw binary data + | { text: string } // plain text + | { content: DocumentContentBlockData[] } // structured content + | { location: S3LocationData } // remote location reference + +/** + * Source for a document (Class version). + */ +export type DocumentSource = + | { type: 'documentSourceBytes'; bytes: Uint8Array } + | { type: 'documentSourceText'; text: string } + | { type: 'documentSourceContentBlock'; content: DocumentContentBlock[] } + | { type: 'documentSourceS3Location'; location: S3Location } + +/** + * Data for a document block. + */ +export interface DocumentBlockData { + /** + * Document name. + */ + name: string + + /** + * Document format. + */ + format: DocumentFormat + + /** + * Document source. + */ + source: DocumentSourceData + + /** + * Citation configuration. + */ + citations?: { enabled: boolean } + + /** + * Context information for the document. + */ + context?: string +} + +/** + * Document content block. + */ +export class DocumentBlock implements DocumentBlockData, JSONSerializable<{ document: Serialized }> { + /** + * Discriminator for document content. + */ + readonly type = 'documentBlock' as const + + /** + * Document name. + */ + readonly name: string + + /** + * Document format. + */ + readonly format: DocumentFormat + + /** + * Document source. + */ + readonly source: DocumentSource + + /** + * Citation configuration. + */ + readonly citations?: { enabled: boolean } + + /** + * Context information for the document. + */ + readonly context?: string + + constructor(data: DocumentBlockData) { + this.name = data.name + this.format = data.format + this.source = this._convertSource(data.source) + if (data.citations !== undefined) { + this.citations = data.citations + } + if (data.context !== undefined) { + this.context = data.context + } + } + + private _convertSource(source: DocumentSourceData): DocumentSource { + if ('bytes' in source) { + return { + type: 'documentSourceBytes', + bytes: source.bytes, + } + } + if ('text' in source) { + return { + type: 'documentSourceText', + text: source.text, + } + } + if ('content' in source) { + return { + type: 'documentSourceContentBlock', + content: source.content.map((block) => new TextBlock(block.text)), + } + } + if ('location' in source) { + return { + type: 'documentSourceS3Location', + location: new S3Location(source.location), + } + } + throw new Error('Invalid document source') + } + + /** + * Serializes the DocumentBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + * Uint8Array bytes are encoded as base64 string. + */ + toJSON(): { document: Serialized } { + let source: Serialized + if (this.source.type === 'documentSourceBytes') { + source = { bytes: encodeBase64(this.source.bytes) } + } else if (this.source.type === 'documentSourceText') { + source = { text: this.source.text } + } else if (this.source.type === 'documentSourceContentBlock') { + source = { content: this.source.content.map((block) => block.toJSON()) } + } else { + source = { location: this.source.location.toJSON() } + } + return { + document: omitUndefined({ + name: this.name, + format: this.format, + source, + citations: this.citations, + context: this.context, + }), + } + } + + /** + * Creates a DocumentBlock instance from its wrapped data format. + * Base64-encoded bytes are decoded back to Uint8Array. + * + * @param data - Wrapped DocumentBlockData to deserialize (accepts both string and Uint8Array for bytes) + * @returns DocumentBlock instance + */ + static fromJSON(data: { document: MaybeSerializedInput }): DocumentBlock { + const doc = data.document + let source: DocumentSourceData + if ('bytes' in doc.source) { + const bytes = doc.source.bytes + source = { bytes: typeof bytes === 'string' ? decodeBase64(bytes) : bytes } + } else if ('text' in doc.source) { + source = { text: doc.source.text } + } else if ('content' in doc.source) { + source = { content: doc.source.content } + } else { + source = { location: doc.source.location } + } + const result: DocumentBlockData = { + name: doc.name, + format: doc.format, + source, + } + if (doc.citations !== undefined) { + result.citations = doc.citations + } + if (doc.context !== undefined) { + result.context = doc.context + } + return new DocumentBlock(result) + } +} diff --git a/strands-ts/src/types/messages.ts b/strands-ts/src/types/messages.ts new file mode 100644 index 0000000000..a5bd8f24f4 --- /dev/null +++ b/strands-ts/src/types/messages.ts @@ -0,0 +1,946 @@ +import type { JSONValue, Serialized, MaybeSerializedInput, JSONSerializable } from './json.js' +import { omitUndefined } from './json.js' +import type { ImageBlockData, VideoBlockData, DocumentBlockData } from './media.js' +import { ImageBlock, VideoBlock, DocumentBlock, encodeBase64, decodeBase64 } from './media.js' +import type { CitationsBlockData } from './citations.js' +import { CitationsBlock } from './citations.js' +import type { Usage, Metrics } from '../models/streaming.js' + +/** + * Message types and content blocks for conversational AI interactions. + * + * This module follows a pattern where "Data" interfaces define the structure + * for objects, while corresponding classes extend those interfaces with additional + * functionality and type discrimination. + */ + +/** + * Optional metadata attached to a message. + * + * Not sent to model providers — model providers construct their own message format + * from `role` and `content` only. Persisted alongside the message in session storage. + */ +export interface MessageMetadata { + /** Token usage information from the model response. */ + usage?: Usage + /** Performance metrics from the model response. */ + metrics?: Metrics + /** Arbitrary user/framework metadata (e.g. compression provenance). */ + custom?: Record +} + +/** + * Data for a message. + */ +export interface MessageData { + /** + * The role of the message sender. + */ + role: Role + + /** + * Array of content blocks that make up this message. + */ + content: ContentBlockData[] + + /** + * Optional metadata, not sent to model providers. + */ + metadata?: MessageMetadata +} + +/** + * A message in a conversation between user and assistant. + * Each message has a role (user or assistant) and an array of content blocks. + */ +export class Message implements JSONSerializable { + /** + * Discriminator for message type. + */ + readonly type = 'message' as const + + /** + * The role of the message sender. + */ + readonly role: Role + + /** + * Array of content blocks that make up this message. + */ + readonly content: ContentBlock[] + + /** + * Optional metadata, not sent to model providers. + */ + readonly metadata?: MessageMetadata + + constructor(data: { role: Role; content: ContentBlock[]; metadata?: MessageMetadata }) { + this.role = data.role + this.content = data.content + if (data.metadata !== undefined) { + this.metadata = data.metadata + } + } + + /** + * Creates a Message instance from MessageData. + */ + public static fromMessageData(data: MessageData): Message { + const contentBlocks: ContentBlock[] = data.content.map(contentBlockFromData) + + return new Message({ + role: data.role, + content: contentBlocks, + ...(data.metadata !== undefined && { metadata: data.metadata }), + }) + } + + /** + * Serializes the Message to a JSON-compatible MessageData object. + * Called automatically by JSON.stringify(). + */ + toJSON(): MessageData { + return { + role: this.role, + content: this.content.map((block) => block.toJSON() as ContentBlockData), + ...(this.metadata !== undefined && { metadata: this.metadata }), + } + } + + /** + * Creates a Message instance from MessageData. + * Alias for fromMessageData for API consistency. + * + * @param data - MessageData to deserialize + * @returns Message instance + */ + static fromJSON(data: MessageData): Message { + return Message.fromMessageData(data) + } +} + +/** + * Role of a message in a conversation. + * Can be either 'user' (human input) or 'assistant' (model response). + */ +export type Role = 'user' | 'assistant' + +/** + * A block of content within a message. + * Content blocks can contain text, tool usage requests, tool results, reasoning content, cache points, guard content, or media (image, video, document). + * + * This is a discriminated union where the object key determines the content format. + * + * @example + * ```typescript + * if ('text' in block) { + * console.log(block.text.text) + * } + * ``` + */ +export type ContentBlockData = + | TextBlockData + | { toolUse: ToolUseBlockData } + | { toolResult: ToolResultBlockData } + | { reasoning: ReasoningBlockData } + | { cachePoint: CachePointBlockData } + | { guardContent: GuardContentBlockData } + | { image: ImageBlockData } + | { video: VideoBlockData } + | { document: DocumentBlockData } + | { citations: CitationsBlockData } + +export type ContentBlock = + | TextBlock + | ToolUseBlock + | ToolResultBlock + | ReasoningBlock + | CachePointBlock + | GuardContentBlock + | ImageBlock + | VideoBlock + | DocumentBlock + | CitationsBlock + +/** + * Data for a text block. + */ +export interface TextBlockData { + /** + * Plain text content. + */ + text: string +} + +/** + * Text content block within a message. + */ +export class TextBlock implements TextBlockData, JSONSerializable { + /** + * Discriminator for text content. + */ + readonly type = 'textBlock' as const + + /** + * Plain text content. + */ + readonly text: string + + constructor(data: string) { + this.text = data + } + + /** + * Serializes the TextBlock to a JSON-compatible TextBlockData object. + * Called automatically by JSON.stringify(). + */ + toJSON(): TextBlockData { + return { text: this.text } + } + + /** + * Creates a TextBlock instance from TextBlockData. + * + * @param data - TextBlockData to deserialize + * @returns TextBlock instance + */ + static fromJSON(data: TextBlockData): TextBlock { + return new TextBlock(data.text) + } +} + +/** + * Data for a tool use block. + */ +export interface ToolUseBlockData { + /** + * The name of the tool to execute. + */ + name: string + + /** + * Unique identifier for this tool use instance. + */ + toolUseId: string + + /** + * The input parameters for the tool. + * This can be any JSON-serializable value. + */ + input: JSONValue + + /** + * Reasoning signature from thinking models (e.g., Gemini). + * Must be preserved and sent back to the model for multi-turn tool use. + */ + reasoningSignature?: string +} + +/** + * Tool use content block. + */ +export class ToolUseBlock implements ToolUseBlockData, JSONSerializable<{ toolUse: ToolUseBlockData }> { + /** + * Discriminator for tool use content. + */ + readonly type = 'toolUseBlock' as const + + /** + * The name of the tool to execute. + */ + readonly name: string + + /** + * Unique identifier for this tool use instance. + */ + readonly toolUseId: string + + /** + * The input parameters for the tool. + * This can be any JSON-serializable value. + */ + readonly input: JSONValue + + /** + * Reasoning signature from thinking models (e.g., Gemini). + * Must be preserved and sent back to the model for multi-turn tool use. + */ + readonly reasoningSignature?: string + + constructor(data: ToolUseBlockData) { + this.name = data.name + this.toolUseId = data.toolUseId + this.input = data.input + if (data.reasoningSignature !== undefined) { + this.reasoningSignature = data.reasoningSignature + } + } + + /** + * Serializes the ToolUseBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + */ + toJSON(): { toolUse: ToolUseBlockData } { + return { + toolUse: omitUndefined({ + name: this.name, + toolUseId: this.toolUseId, + input: this.input, + reasoningSignature: this.reasoningSignature, + }), + } + } + + /** + * Creates a ToolUseBlock instance from its wrapped data format. + * + * @param data - Wrapped ToolUseBlockData to deserialize + * @returns ToolUseBlock instance + */ + static fromJSON(data: { toolUse: ToolUseBlockData }): ToolUseBlock { + return new ToolUseBlock(data.toolUse) + } +} + +/** + * Content within a tool result. + * Can be text, structured JSON data, or media blocks (image, video, document). + * + * This is a discriminated union where the object key determines the content format. + */ +export type ToolResultContentData = + | TextBlockData + | JsonBlockData + | { image: ImageBlockData } + | { video: VideoBlockData } + | { document: DocumentBlockData } + +export type ToolResultContent = TextBlock | JsonBlock | ImageBlock | VideoBlock | DocumentBlock + +/** + * Data for a tool result block. + */ +export interface ToolResultBlockData { + /** + * The ID of the tool use that this result corresponds to. + */ + toolUseId: string + + /** + * Status of the tool execution. + */ + status: 'success' | 'error' + + /** + * The content returned by the tool. + */ + content: ToolResultContentData[] + + /** + * The original error object when status is 'error'. + * Available for inspection by hooks, error handlers, and agent loop. + * Tools must wrap non-Error thrown values into Error objects. + */ + error?: Error +} + +/** + * Tool result content block. + */ +export class ToolResultBlock implements JSONSerializable<{ toolResult: ToolResultBlockData }> { + /** + * Discriminator for tool result content. + */ + readonly type = 'toolResultBlock' as const + + /** + * The ID of the tool use that this result corresponds to. + */ + readonly toolUseId: string + + /** + * Status of the tool execution. + */ + readonly status: 'success' | 'error' + + /** + * The content returned by the tool. + */ + readonly content: ToolResultContent[] + + /** + * The original error object when status is 'error'. + * Available for inspection by hooks, error handlers, and agent loop. + * Tools must wrap non-Error thrown values into Error objects. + */ + readonly error?: Error + + constructor(data: { toolUseId: string; status: 'success' | 'error'; content: ToolResultContent[]; error?: Error }) { + this.toolUseId = data.toolUseId + this.status = data.status + this.content = data.content + if (data.error !== undefined) { + this.error = data.error + } + } + + /** + * Serializes the ToolResultBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + * Note: The error field is not serialized (deferred for future implementation). + */ + toJSON(): { toolResult: ToolResultBlockData } { + return { + toolResult: { + toolUseId: this.toolUseId, + status: this.status, + content: this.content.map((block) => block.toJSON() as ToolResultContentData), + }, + } + } + + /** + * Creates a ToolResultBlock instance from its wrapped data format. + * + * @param data - Wrapped ToolResultBlockData to deserialize + * @returns ToolResultBlock instance + */ + static fromJSON(data: { toolResult: ToolResultBlockData }): ToolResultBlock { + const content = data.toolResult.content.map(toolResultContentFromData) + return new ToolResultBlock({ + toolUseId: data.toolResult.toolUseId, + status: data.toolResult.status, + content, + }) + } +} + +/** + * Converts a single ToolResultContentData to a ToolResultContent class instance. + * + * @param data - The tool result content data to convert + * @returns A ToolResultContent instance of the appropriate type + * @throws Error if the content data type is unknown + */ +export function toolResultContentFromData(data: ToolResultContentData): ToolResultContent { + if ('text' in data) return new TextBlock(data.text) + if ('json' in data) return new JsonBlock(data) + if ('image' in data) return ImageBlock.fromJSON(data as { image: ImageBlockData }) + if ('video' in data) return VideoBlock.fromJSON(data as { video: VideoBlockData }) + if ('document' in data) return DocumentBlock.fromJSON(data as { document: DocumentBlockData }) + throw new Error('Unknown ToolResultContentData type') +} + +/** + * Data for a reasoning block. + */ +export interface ReasoningBlockData { + /** + * The text content of the reasoning process. + */ + text?: string + + /** + * A cryptographic signature for verification purposes. + */ + signature?: string + + /** + * The redacted content of the reasoning process. + */ + redactedContent?: Uint8Array +} + +/** + * Reasoning content block within a message. + */ +export class ReasoningBlock + implements ReasoningBlockData, JSONSerializable<{ reasoning: Serialized }> +{ + /** + * Discriminator for reasoning content. + */ + readonly type = 'reasoningBlock' as const + + /** + * The text content of the reasoning process. + */ + readonly text?: string + + /** + * A cryptographic signature for verification purposes. + */ + readonly signature?: string + + /** + * The redacted content of the reasoning process. + */ + readonly redactedContent?: Uint8Array + + constructor(data: ReasoningBlockData) { + if (data.text !== undefined) { + this.text = data.text + } + if (data.signature !== undefined) { + this.signature = data.signature + } + if (data.redactedContent !== undefined) { + this.redactedContent = data.redactedContent + } + } + + /** + * Serializes the ReasoningBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + * Uint8Array redactedContent is encoded as base64 string. + */ + toJSON(): { reasoning: Serialized } { + return { + reasoning: omitUndefined({ + text: this.text, + signature: this.signature, + redactedContent: this.redactedContent ? encodeBase64(this.redactedContent) : undefined, + }), + } + } + + /** + * Creates a ReasoningBlock instance from its wrapped data format. + * Base64-encoded redactedContent is decoded back to Uint8Array. + * + * @param data - Wrapped ReasoningBlockData to deserialize (accepts both string and Uint8Array for redactedContent) + * @returns ReasoningBlock instance + */ + static fromJSON(data: { reasoning: MaybeSerializedInput }): ReasoningBlock { + const reasoning = data.reasoning + const result: ReasoningBlockData = {} + if (reasoning.text !== undefined) { + result.text = reasoning.text + } + if (reasoning.signature !== undefined) { + result.signature = reasoning.signature + } + if (reasoning.redactedContent !== undefined) { + result.redactedContent = + typeof reasoning.redactedContent === 'string' + ? decodeBase64(reasoning.redactedContent) + : reasoning.redactedContent + } + return new ReasoningBlock(result) + } +} + +/** + * Data for a cache point block. + */ +export interface CachePointBlockData { + /** + * The cache type. Currently only 'default' is supported. + */ + cacheType: 'default' + + /** + * Optional TTL for the cache entry. When omitted, the provider's default TTL is used. + * + * The accepted value space is provider-specific. For example, the Bedrock provider only + * accepts the values defined by `BedrockCacheTTL` (`'5m'` and `'1h'`). Other providers + * may accept different values or ignore this field. + */ + ttl?: string +} + +/** + * Cache point block for prompt caching. + * Marks a position in a message or system prompt where caching should occur. + */ +export class CachePointBlock implements CachePointBlockData, JSONSerializable<{ cachePoint: CachePointBlockData }> { + /** + * Discriminator for cache point. + */ + readonly type = 'cachePointBlock' as const + + /** + * The cache type. Currently only 'default' is supported. + */ + readonly cacheType: 'default' + + /** + * Optional TTL for the cache entry. See {@link CachePointBlockData.ttl} for the + * provider-specific value space. + */ + readonly ttl?: string + + constructor(data: CachePointBlockData) { + this.cacheType = data.cacheType + if (data.ttl !== undefined) { + this.ttl = data.ttl + } + } + + /** + * Serializes the CachePointBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + */ + toJSON(): { cachePoint: CachePointBlockData } { + return { + cachePoint: { + cacheType: this.cacheType, + ...(this.ttl !== undefined && { ttl: this.ttl }), + }, + } + } + + /** + * Creates a CachePointBlock instance from its wrapped data format. + * + * @param data - Wrapped CachePointBlockData to deserialize + * @returns CachePointBlock instance + */ + static fromJSON(data: { cachePoint: CachePointBlockData }): CachePointBlock { + return new CachePointBlock(data.cachePoint) + } +} + +/** + * Data for a JSON block. + */ +export interface JsonBlockData { + /** + * Structured JSON data. + */ + json: JSONValue +} + +/** + * JSON content block within a message. + * Used for structured data returned from tools or model responses. + */ +export class JsonBlock implements JsonBlockData, JSONSerializable { + /** + * Discriminator for JSON content. + */ + readonly type = 'jsonBlock' as const + + /** + * Structured JSON data. + */ + readonly json: JSONValue + + constructor(data: JsonBlockData) { + this.json = data.json + } + + /** + * Serializes the JsonBlock to a JSON-compatible JsonBlockData object. + * Called automatically by JSON.stringify(). + */ + toJSON(): JsonBlockData { + return { json: this.json } + } + + /** + * Creates a JsonBlock instance from JsonBlockData. + * + * @param data - JsonBlockData to deserialize + * @returns JsonBlock instance + */ + static fromJSON(data: JsonBlockData): JsonBlock { + return new JsonBlock(data) + } +} + +/** + * Reason why the model stopped generating content. + * + * - `cancelled` - Agent invocation was cancelled via `agent.cancel()` + * - `contentFiltered` - Content was filtered by safety mechanisms + * - `endTurn` - Natural end of the model's turn + * - `guardrailIntervened` - A guardrail policy stopped generation + * - `interrupt` - Agent execution was interrupted for human input + * - `maxTokens` - Maximum token limit was reached + * - `pauseTurn` - Model paused a long-running turn; the response should be sent back to continue + * - `refusal` - A streaming classifier intervened to handle a potential policy violation + * - `stopSequence` - A stop sequence was encountered + * - `toolUse` - Model wants to use a tool + * - `modelContextWindowExceeded` - Input exceeded the model's context window + */ +export type StopReason = + | 'cancelled' + | 'contentFiltered' + | 'endTurn' + | 'guardrailIntervened' + | 'interrupt' + | 'maxTokens' + | 'pauseTurn' + | 'refusal' + | 'stopSequence' + | 'toolUse' + | 'modelContextWindowExceeded' + | (string & {}) // Allow any string while preserving autocomplete for known values + +/** + * System prompt for guiding model behavior. + * Can be a simple string or an array of content blocks for advanced caching. + * + * @example + * ```typescript + * // Simple string + * const prompt: SystemPrompt = 'You are a helpful assistant' + * + * // Array with cache points for advanced caching + * const prompt: SystemPrompt = [ + * { textBlock: new TextBlock('You are a helpful assistant') }, + * { textBlock: new TextBlock(largeContextDocument) }, + * { cachePointBlock: new CachePointBlock({ cacheType: 'default' }) } + * ] + * ``` + */ +export type SystemPrompt = string | SystemContentBlock[] + +/** + * Data representation of a system prompt. + * Can be a simple string or an array of system content block data for advanced caching. + * + * This is the data interface counterpart to SystemPrompt, following the "Data" pattern. + */ +export type SystemPromptData = string | SystemContentBlockData[] + +/** + * Converts SystemPromptData to SystemPrompt by converting data blocks to class instances. + * If already in SystemPrompt format (class instances), returns as-is. + * + * @param data - System prompt data to convert + * @returns SystemPrompt with class-based content blocks + */ +export function systemPromptFromData(data: SystemPromptData | SystemPrompt): SystemPrompt { + if (typeof data === 'string') { + return data + } + + // Convert data format to class instances + return data.map((block) => { + if ('type' in block) { + return block + } else if ('cachePoint' in block) { + return new CachePointBlock(block.cachePoint) + } else if ('guardContent' in block) { + return new GuardContentBlock(block.guardContent) + } else if ('text' in block) { + return new TextBlock(block.text) + } else { + throw new Error('Unknown SystemContentBlockData type') + } + }) +} + +/** + * Converts a SystemPrompt to its data representation for serialization. + * + * @param prompt - System prompt to convert (string or content block array) + * @returns SystemPromptData suitable for JSON serialization + */ +export function systemPromptToData(prompt: SystemPrompt): SystemPromptData { + if (typeof prompt === 'string') { + return prompt + } + // Convert content blocks to their data representation + return prompt.map((block: SystemContentBlock) => block.toJSON()) as SystemContentBlockData[] +} + +/** + * A block of content within a system prompt. + * Supports text content, cache points, and guard content for prompt caching and guardrail evaluation. + * + * This is a discriminated union where the object key determines the block format. + */ +export type SystemContentBlockData = + | TextBlockData + | { cachePoint: CachePointBlockData } + | { guardContent: GuardContentBlockData } + +export type SystemContentBlock = TextBlock | CachePointBlock | GuardContentBlock + +/** + * Qualifier for guard content. + * Specifies how the content should be evaluated by guardrails. + * + * - `grounding_source` - Content to check for grounding/factuality + * - `query` - User query to evaluate + * - `guard_content` - General content for guardrail evaluation + */ +export type GuardQualifier = 'grounding_source' | 'query' | 'guard_content' + +/** + * Image format for guard content. + * Only formats supported by Bedrock guardrails. + */ +export type GuardImageFormat = 'png' | 'jpeg' + +/** + * Source for guard content image. + * Only supports raw bytes. + */ +export type GuardImageSource = { bytes: Uint8Array } + +/** + * Text content to be evaluated by guardrails. + */ +export interface GuardContentText { + /** + * Qualifiers that specify how this content should be evaluated. + */ + qualifiers: GuardQualifier[] + + /** + * The text content to be evaluated. + */ + text: string +} + +/** + * Image content to be evaluated by guardrails. + */ +export interface GuardContentImage { + /** + * Image format. + */ + format: GuardImageFormat + + /** + * Image source (bytes only). + */ + source: GuardImageSource +} + +/** + * Data for a guard content block. + * Can contain either text or image content for guardrail evaluation. + */ +export interface GuardContentBlockData { + /** + * Text content with evaluation qualifiers. + */ + text?: GuardContentText + + /** + * Image content with evaluation qualifiers. + */ + image?: GuardContentImage +} + +/** + * Guard content block for guardrail evaluation. + * Marks content that should be evaluated by guardrails for safety, grounding, or other policies. + * Can be used in both message content and system prompts. + */ +export class GuardContentBlock + implements GuardContentBlockData, JSONSerializable<{ guardContent: Serialized }> +{ + /** + * Discriminator for guard content. + */ + readonly type = 'guardContentBlock' as const + + /** + * Text content with evaluation qualifiers. + */ + readonly text?: GuardContentText + + /** + * Image content with evaluation qualifiers. + */ + readonly image?: GuardContentImage + + constructor(data: GuardContentBlockData) { + if (!data.text && !data.image) { + throw new Error('GuardContentBlock must have either text or image content') + } + if (data.text && data.image) { + throw new Error('GuardContentBlock cannot have both text and image content') + } + if (data.text) { + this.text = data.text + } + if (data.image) { + this.image = data.image + } + } + + /** + * Serializes the GuardContentBlock to a JSON-compatible ContentBlockData object. + * Called automatically by JSON.stringify(). + * Uint8Array image bytes are encoded as base64 string. + */ + toJSON(): { guardContent: Serialized } { + const data: Serialized = {} + if (this.text) { + data.text = this.text + } + if (this.image) { + data.image = { + format: this.image.format, + source: { bytes: encodeBase64(this.image.source.bytes) }, + } + } + return { guardContent: data } + } + + /** + * Creates a GuardContentBlock instance from its wrapped data format. + * Base64-encoded image bytes are decoded back to Uint8Array. + * + * @param data - Wrapped GuardContentBlockData to deserialize (accepts both string and Uint8Array for image bytes) + * @returns GuardContentBlock instance + */ + static fromJSON(data: { guardContent: MaybeSerializedInput }): GuardContentBlock { + const guardContent = data.guardContent + const result: GuardContentBlockData = {} + if (guardContent.text) { + result.text = guardContent.text + } + if (guardContent.image) { + const bytes = guardContent.image.source.bytes + result.image = { + format: guardContent.image.format, + source: { + bytes: typeof bytes === 'string' ? decodeBase64(bytes) : bytes, + }, + } + } + return new GuardContentBlock(result) + } +} + +/** + * Converts ContentBlockData to a ContentBlock instance. + * Handles all content block types including text, tool use/result, reasoning, cache points, guard content, and media blocks. + * + * @param data - The content block data to convert + * @returns A ContentBlock instance of the appropriate type + * @throws Error if the content block type is unknown + */ +export function contentBlockFromData(data: ContentBlockData): ContentBlock { + if ('text' in data) { + return new TextBlock(data.text) + } else if ('toolUse' in data) { + return new ToolUseBlock(data.toolUse) + } else if ('toolResult' in data) { + return ToolResultBlock.fromJSON(data) + } else if ('reasoning' in data) { + return ReasoningBlock.fromJSON(data) + } else if ('cachePoint' in data) { + return CachePointBlock.fromJSON(data) + } else if ('guardContent' in data) { + return GuardContentBlock.fromJSON(data) + } else if ('image' in data) { + return ImageBlock.fromJSON(data) + } else if ('video' in data) { + return VideoBlock.fromJSON(data) + } else if ('document' in data) { + return DocumentBlock.fromJSON(data) + } else if ('citations' in data) { + return CitationsBlock.fromJSON(data) + } else { + throw new Error('Unknown ContentBlockData type') + } +} diff --git a/strands-ts/src/types/serializable.ts b/strands-ts/src/types/serializable.ts new file mode 100644 index 0000000000..5151f2f1de --- /dev/null +++ b/strands-ts/src/types/serializable.ts @@ -0,0 +1,79 @@ +/** + * Serialization interfaces for state persistence. + * + * This module provides interfaces for objects that can serialize and deserialize + * their state, enabling persistence and restoration of runtime state. + * + * StateSerializable uses symbol-keyed methods to keep the serialization API internal, + * preventing accidental usage by customers (e.g., accessing agent.appState.toJSON() directly). + */ + +import type { JSONValue } from './json.js' + +/** + * Symbol for the serialization method on StateSerializable objects. + */ +export const stateToJSONSymbol = Symbol('StateSerializable.toJSON') + +/** + * Symbol for the deserialization method on StateSerializable objects. + */ +export const loadStateFromJSONSymbol = Symbol('StateSerializable.loadStateFromJSON') + +/** + * Interface for mutable state containers that can serialize and restore their state. + * Uses symbol-keyed methods to keep the API internal. + * + * Use JSONSerializable for immutable value objects (with static fromJSON). + * Use StateSerializable for mutable state that loads into an existing instance. + */ +export interface StateSerializable { + /** + * Serializes the state to a JSON value. + * + * @returns The serialized state + */ + [stateToJSONSymbol](): JSONValue + + /** + * Loads state from a previously serialized JSON value. + * + * @param json - The serialized state to load + */ + [loadStateFromJSONSymbol](json: JSONValue): void +} + +/** + * Type guard to check if an object implements StateSerializable. + * + * @param obj - The object to check + * @returns True if the object implements StateSerializable + */ +export function isStateSerializable(obj: unknown): obj is StateSerializable { + return ( + obj !== null && + typeof obj === 'object' && + typeof (obj as StateSerializable)[stateToJSONSymbol] === 'function' && + typeof (obj as StateSerializable)[loadStateFromJSONSymbol] === 'function' + ) +} + +/** + * Serializes a StateSerializable object to JSON. + * + * @param obj - The StateSerializable object to serialize + * @returns The serialized JSON value + */ +export function serializeStateSerializable(obj: StateSerializable): JSONValue { + return obj[stateToJSONSymbol]() +} + +/** + * Loads state from JSON into a StateSerializable object. + * + * @param obj - The StateSerializable object to load state into + * @param json - The JSON value to load + */ +export function loadStateSerializable(obj: StateSerializable, json: JSONValue): void { + obj[loadStateFromJSONSymbol](json) +} diff --git a/strands-ts/src/types/snapshot.ts b/strands-ts/src/types/snapshot.ts new file mode 100644 index 0000000000..857953aecf --- /dev/null +++ b/strands-ts/src/types/snapshot.ts @@ -0,0 +1,31 @@ +/** + * Shared snapshot types for agent and multi-agent snapshots. + */ + +import type { JSONValue } from './json.js' + +/** + * Current schema version of the snapshot format. + */ +export const SNAPSHOT_SCHEMA_VERSION = '1.0' + +/** + * Scope defines the context for snapshot data. + */ +export type Scope = 'agent' | 'multiAgent' + +/** + * Point-in-time capture of agent or orchestrator state. + */ +export interface Snapshot { + /** Scope identifying the snapshot context (agent or multi-agent). */ + scope: Scope + /** Schema version string for forward compatibility. */ + schemaVersion: string + /** ISO 8601 timestamp of when snapshot was created. */ + createdAt: string + /** Framework-owned state data. */ + data: Record + /** Application-owned data. Strands does not read or modify this. */ + appData: Record +} diff --git a/strands-ts/src/types/validation.ts b/strands-ts/src/types/validation.ts new file mode 100644 index 0000000000..54f34223b8 --- /dev/null +++ b/strands-ts/src/types/validation.ts @@ -0,0 +1,14 @@ +/** + * Ensures a value is defined, throwing an error if it's null or undefined. + * + * @param value - The value to check + * @param fieldName - Name of the field for error reporting + * @returns The value if defined + * @throws Error if value is null or undefined + */ +export function ensureDefined(value: T | null | undefined, fieldName: string): T { + if (value == null) { + throw new Error(`Expected ${fieldName} to be defined, but got ${value}`) + } + return value +} diff --git a/strands-ts/src/utils/shell-quote.ts b/strands-ts/src/utils/shell-quote.ts new file mode 100644 index 0000000000..a11fe10bcd --- /dev/null +++ b/strands-ts/src/utils/shell-quote.ts @@ -0,0 +1,13 @@ +/** + * Shell-escape a string for safe inclusion in a shell command. + * + * Wraps the value in single quotes and escapes any embedded single quotes + * using the '\'' pattern. Single quotes disable all shell expansion + * (variables, backticks, globbing), making this safe against injection. + * + * @param value - The string to escape. + * @returns The shell-escaped string wrapped in single quotes. + */ +export function shellQuote(value: string): string { + return "'" + value.replace(/'/g, "'\\''") + "'" +} diff --git a/strands-ts/src/vended-interventions/hitl/__tests__/hitl.test.ts b/strands-ts/src/vended-interventions/hitl/__tests__/hitl.test.ts new file mode 100644 index 0000000000..e43897d9b6 --- /dev/null +++ b/strands-ts/src/vended-interventions/hitl/__tests__/hitl.test.ts @@ -0,0 +1,428 @@ +import { describe, expect, it } from 'vitest' +import { HumanInTheLoop } from '../hitl.js' +import { Agent } from '../../../agent/agent.js' +import { MockMessageModel } from '../../../__fixtures__/mock-message-model.js' +import { createMockTool } from '../../../__fixtures__/tool-helpers.js' + +describe('HumanInTheLoop', () => { + describe('default config (interrupt/resume)', () => { + it('pauses agent with interrupt on any tool call', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'anyTool', toolUseId: 'tool-1', input: { x: 1 } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolExecuted = false + const tool = createMockTool('anyTool', () => { + toolExecuted = true + return 'result' + }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [new HumanInTheLoop()], + printer: false, + }) + + const result = await agent.invoke('Do something') + + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts).toEqual([ + expect.objectContaining({ + name: 'strands:human-in-the-loop', + reason: expect.stringContaining('anyTool'), + }), + ]) + expect(toolExecuted).toBe(false) + }) + }) + + describe('inline mode (with ask callback)', () => { + it('allows tool execution when approved', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolExecuted = false + const tool = createMockTool('myTool', () => { + toolExecuted = true + return 'executed' + }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [new HumanInTheLoop({ ask: async () => 'yes' })], + printer: false, + }) + + const result = await agent.invoke('Run tool') + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(true) + }) + + it('denies tool execution when rejected', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Understood' }) + + let toolExecuted = false + const tool = createMockTool('myTool', () => { + toolExecuted = true + return 'executed' + }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [new HumanInTheLoop({ ask: async () => 'no' })], + printer: false, + }) + + const result = await agent.invoke('Run tool') + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(false) + }) + }) + + describe('allowedTools config', () => { + it('does not prompt for tools in allowedTools', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'readFile', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolExecuted = false + const tool = createMockTool('readFile', () => { + toolExecuted = true + return 'content' + }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [new HumanInTheLoop({ allowedTools: ['readFile'], ask: async () => 'no' })], + printer: false, + }) + + const result = await agent.invoke('Read it') + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(true) + }) + + it('prompts for tools not in allowedTools', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'deleteFile', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolExecuted = false + const tool = createMockTool('deleteFile', () => { + toolExecuted = true + return 'deleted' + }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [new HumanInTheLoop({ allowedTools: ['readFile'], ask: async () => 'no' })], + printer: false, + }) + + await agent.invoke('Delete it') + expect(toolExecuted).toBe(false) + }) + + it('allows all tools except negated ones with "!" prefix', async () => { + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'readFile', toolUseId: 'tool-1', input: {} }, + { type: 'toolUseBlock', name: 'deleteFile', toolUseId: 'tool-2', input: {} }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const execLog: string[] = [] + const readTool = createMockTool('readFile', () => { + execLog.push('read') + return 'content' + }) + const deleteTool = createMockTool('deleteFile', () => { + execLog.push('delete') + return 'deleted' + }) + + const agent = new Agent({ + model, + tools: [readTool, deleteTool], + interventions: [new HumanInTheLoop({ allowedTools: ['*', '!deleteFile'], ask: async () => 'no' })], + printer: false, + }) + + await agent.invoke('Do both') + + expect(execLog).toContain('read') + expect(execLog).not.toContain('delete') + }) + + it('allows all tools with wildcard "*"', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'dangerousTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolExecuted = false + const tool = createMockTool('dangerousTool', () => { + toolExecuted = true + return 'ran' + }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [new HumanInTheLoop({ allowedTools: ['*'], ask: async () => 'no' })], + printer: false, + }) + + const result = await agent.invoke('Do it') + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(true) + }) + }) + + describe('ask callback', () => { + it('passes tool name and input in the prompt', async () => { + const prompts: string[] = [] + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'sendEmail', toolUseId: 'tool-1', input: { to: 'bob@example.com' } }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('sendEmail', () => 'sent') + + const agent = new Agent({ + model, + tools: [tool], + interventions: [ + new HumanInTheLoop({ + ask: async (prompt) => { + prompts.push(prompt) + return 'yes' + }, + }), + ], + printer: false, + }) + + await agent.invoke('Send email') + + expect(prompts[0]).toContain('sendEmail') + expect(prompts[0]).toContain('bob@example.com') + }) + + it('supports custom evaluate function', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolExecuted = false + const tool = createMockTool('myTool', () => { + toolExecuted = true + return 'executed' + }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [ + new HumanInTheLoop({ + ask: async () => 'magic-word', + evaluate: (response) => response === 'magic-word', + }), + ], + printer: false, + }) + + const result = await agent.invoke('Go') + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(true) + }) + }) + + describe('trust mode (enableTrust: true)', () => { + it('trusts a tool for the session when response is "t"', async () => { + let askCount = 0 + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-2', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('myTool', () => 'executed') + + const agent = new Agent({ + model, + tools: [tool], + interventions: [ + new HumanInTheLoop({ + enableTrust: true, + ask: async () => { + askCount++ + return 't' + }, + }), + ], + printer: false, + }) + + await agent.invoke('Run tool twice') + + expect(askCount).toBe(1) + }) + + it('does not trust when enableTrust is false even with "t" response', async () => { + let askCount = 0 + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-2', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('myTool', () => 'executed') + + const agent = new Agent({ + model, + tools: [tool], + interventions: [ + new HumanInTheLoop({ + enableTrust: false, + ask: async () => { + askCount++ + return 't' + }, + }), + ], + printer: false, + }) + + await agent.invoke('Run tool twice') + + // 't' is not recognized as approval when trust is disabled, so tool is denied both times + // but ask is still called both times (no trust memory) + expect(askCount).toBe(2) + }) + + it('"t" response also approves the current tool call', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + let toolExecuted = false + const tool = createMockTool('myTool', () => { + toolExecuted = true + return 'executed' + }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [new HumanInTheLoop({ enableTrust: true, ask: async () => 't' })], + printer: false, + }) + + const result = await agent.invoke('Run tool') + expect(result.stopReason).toBe('endTurn') + expect(toolExecuted).toBe(true) + }) + + it.each(['trust', 'T', 'TRUST'])('trusts when response is "%s"', async (trustResponse) => { + let askCount = 0 + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-2', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('myTool', () => 'executed') + + const agent = new Agent({ + model, + tools: [tool], + interventions: [ + new HumanInTheLoop({ + enableTrust: true, + ask: async () => { + askCount++ + return trustResponse + }, + }), + ], + printer: false, + }) + + await agent.invoke('Run tool twice') + expect(askCount).toBe(1) + }) + + it('supports custom evaluateTrust function', async () => { + let askCount = 0 + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'toolUseBlock', name: 'myTool', toolUseId: 'tool-2', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('myTool', () => 'executed') + + const agent = new Agent({ + model, + tools: [tool], + interventions: [ + new HumanInTheLoop({ + enableTrust: true, + evaluateTrust: (r) => r === 'approve-and-remember', + ask: async () => { + askCount++ + return 'approve-and-remember' + }, + }), + ], + printer: false, + }) + + await agent.invoke('Run tool twice') + expect(askCount).toBe(1) + }) + + it('negated tools cannot be trusted even with "t" response', async () => { + let askCount = 0 + let toolExecuted = false + + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'dangerTool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'toolUseBlock', name: 'dangerTool', toolUseId: 'tool-2', input: {} }) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = createMockTool('dangerTool', () => { + toolExecuted = true + return 'ran' + }) + + const agent = new Agent({ + model, + tools: [tool], + interventions: [ + new HumanInTheLoop({ + allowedTools: ['*', '!dangerTool'], + enableTrust: true, + ask: async () => { + askCount++ + return 't' + }, + }), + ], + printer: false, + }) + + await agent.invoke('Run danger twice') + expect(askCount).toBe(2) + expect(toolExecuted).toBe(false) + }) + }) +}) diff --git a/strands-ts/src/vended-interventions/hitl/hitl.ts b/strands-ts/src/vended-interventions/hitl/hitl.ts new file mode 100644 index 0000000000..a2a3cda393 --- /dev/null +++ b/strands-ts/src/vended-interventions/hitl/hitl.ts @@ -0,0 +1,210 @@ +import { InterventionHandler } from '../../interventions/handler.js' +import { confirm, proceed, defaultEvaluate } from '../../interventions/actions.js' +import type { InterventionAction } from '../../interventions/actions.js' +import type { BeforeToolCallEvent } from '../../hooks/events.js' +import type { JSONValue } from '../../types/json.js' + +const TRUST_RESPONSES = new Set(['t', 'trust']) +const TRUSTED_TOOLS_KEY = 'hitl:trustedTools' + +/** + * CLI prompt that reads from stdin. + * Serializes prompts so concurrent tool calls don't collide on stdin. + */ +function createStdioAsk(includeTrust: boolean): (prompt: string) => Promise { + const options = includeTrust ? '(y/n/t)' : '(y/n)' + let queue: Promise = Promise.resolve() + + return (prompt: string) => { + const task = queue.then(async () => { + const { createInterface } = await import('node:readline') + const rl = createInterface({ input: process.stdin, output: process.stdout }) + return new Promise((resolve) => { + rl.question(`${prompt} ${options}: `, (answer) => { + rl.close() + resolve(answer.trim()) + }) + }) + }) + queue = task.catch(() => {}) + return task + } +} + +/** + * Configuration for the {@link HumanInTheLoop} intervention handler. + */ +export interface HumanInTheLoopConfig { + /** + * Tools that can execute WITHOUT human approval. All other tools require approval. + * + * - Use `'*'` to allow all tools. + * - Prefix with `!` to exclude specific tools from `'*'` (they still require approval). + * + * @example + * ```typescript + * // Only readFile and listDir run freely; everything else needs approval + * { allowedTools: ['readFile', 'listDir'] } + * + * // All tools run freely (HITL disabled) + * { allowedTools: ['*'] } + * + * // All tools run freely EXCEPT deleteFile and sendEmail + * { allowedTools: ['*', '!deleteFile', '!sendEmail'] } + * ``` + */ + allowedTools?: string[] + + /** + * When true, trust responses approve the tool AND remember it + * in `agent.appState` for the rest of the session (won't ask again). + * Works in both interrupt/resume and inline `ask` modes. + * + * Negated tools (`!tool`) cannot be trusted. + * + * Defaults to `false`. + */ + enableTrust?: boolean + + /** + * Custom trust response validator. Defaults to accepting `'t'`/`'trust'` (case-insensitive). + * When this returns true, the tool is approved AND trusted for the session. + * + * Only evaluated when `enableTrust` is true. + */ + evaluateTrust?: (response: JSONValue) => boolean + + /** + * Custom approval response validator. Defaults to accepting `true`, `'y'`/`'yes'` (case-insensitive). + */ + evaluate?: (response: JSONValue) => boolean + + /** + * Controls how the human's response is collected. + * + * - **Default** (omitted): uses interrupt/resume — agent pauses, caller resumes with response. + * - **`'stdio'`**: prompts via CLI readline (Node.js only). Agent blocks inline until human responds. + * - **Custom function**: your own async prompt logic (Slack, web UI, etc.). Agent blocks inline. + */ + ask?: ((prompt: string) => Promise) | 'stdio' +} + +/** + * Human-in-the-loop intervention handler that pauses agent execution + * before tool calls to request human approval. + * + * By default, ALL tools require approval and the agent pauses via interrupt/resume. + * Use `allowedTools` to whitelist tools that run freely, and `ask` to provide + * inline prompting (CLI, custom UI). + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { HumanInTheLoop } from '@strands-agents/sdk/vended-interventions/hitl' + * + * // All tools require approval, agent pauses via interrupt (default) + * const agent = new Agent({ + * interventions: [new HumanInTheLoop()], + * }) + * + * // readFile runs freely, everything else pauses for approval + * const agent = new Agent({ + * interventions: [new HumanInTheLoop({ allowedTools: ['readFile'] })], + * }) + * + * // CLI mode — prompts in terminal inline + * const agent = new Agent({ + * interventions: [new HumanInTheLoop({ ask: 'stdio' })], + * }) + * + * // Custom UI — provide your own prompt function + * const agent = new Agent({ + * interventions: [new HumanInTheLoop({ + * ask: async (prompt) => await slackDM(userId, prompt), + * })], + * }) + * ``` + */ +export class HumanInTheLoop extends InterventionHandler { + readonly name = 'strands:human-in-the-loop' + + private readonly _allowedTools: Set + private readonly _enableTrust: boolean + private readonly _evaluateTrust: (response: JSONValue) => boolean + private readonly _evaluate: ((response: JSONValue) => boolean) | undefined + private readonly _ask: ((prompt: string) => Promise) | undefined + + constructor(config?: HumanInTheLoopConfig) { + super() + this._allowedTools = new Set(config?.allowedTools ?? []) + this._enableTrust = config?.enableTrust ?? false + this._evaluateTrust = config?.evaluateTrust ?? ((r: JSONValue): boolean => this._isTrustResponse(r)) + this._evaluate = config?.evaluate + this._ask = config?.ask === 'stdio' ? createStdioAsk(this._enableTrust) : config?.ask + } + + override async beforeToolCall(event: BeforeToolCallEvent): Promise { + const toolName = event.toolUse.name + if (!this._requiresApproval(event)) { + return proceed() + } + + const prompt = `Tool "${toolName}" requires human approval. Input: ${JSON.stringify(event.toolUse.input)}` + + const isNegated = this._allowedTools.has(`!${toolName}`) + + const evaluate = (response: JSONValue): boolean => { + if (!isNegated && this._enableTrust && this._evaluateTrust(response)) { + this._trustTool(event, toolName) + return true + } + return this._evaluate ? this._evaluate(response) : defaultEvaluate(response) + } + + if (!this._ask) { + return confirm(prompt, { evaluate }) + } + + const response = await this._ask(prompt) + + if (!isNegated && this._enableTrust && this._evaluateTrust(response)) { + this._trustTool(event, toolName) + return proceed() + } + + return confirm(prompt, { + response, + evaluate: this._evaluate ?? defaultEvaluate, + }) + } + + /** + * Precedence (first match wins): + * 1. Negated (`!tool`) → always requires approval (cannot be trusted) + * 2. Trusted at runtime via 't' response (stored in agent.appState) → runs freely + * 3. Wildcard (`*`) → runs freely + * 4. Explicitly listed → runs freely + * 5. Default → requires approval + */ + private _requiresApproval(event: BeforeToolCallEvent): boolean { + const toolName = event.toolUse.name + if (this._allowedTools.has(`!${toolName}`)) return true + const trusted = (event.agent.appState.get(TRUSTED_TOOLS_KEY) as string[] | undefined) ?? [] + if (trusted.includes(toolName)) return false + if (this._allowedTools.has('*')) return false + if (this._allowedTools.has(toolName)) return false + return true + } + + private _trustTool(event: BeforeToolCallEvent, toolName: string): void { + const trusted = (event.agent.appState.get(TRUSTED_TOOLS_KEY) as string[] | undefined) ?? [] + if (!trusted.includes(toolName)) { + event.agent.appState.set(TRUSTED_TOOLS_KEY, [...trusted, toolName]) + } + } + + private _isTrustResponse(response: JSONValue): boolean { + if (typeof response === 'string') return TRUST_RESPONSES.has(response.toLowerCase().trim()) + return false + } +} diff --git a/strands-ts/src/vended-interventions/hitl/index.ts b/strands-ts/src/vended-interventions/hitl/index.ts new file mode 100644 index 0000000000..eea6ec8e07 --- /dev/null +++ b/strands-ts/src/vended-interventions/hitl/index.ts @@ -0,0 +1,24 @@ +/** + * Human-in-the-loop intervention for Strands Agents. + * + * Pauses agent execution before tool calls to request human approval. + * Defaults to interrupt/resume mode for stateless deployments. + * Pass `ask: 'stdio'` for CLI prompting or a custom `ask` function for other UIs. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { HumanInTheLoop } from '@strands-agents/sdk/vended-interventions/hitl' + * + * const agent = new Agent({ + * tools: [deleteTool, readTool], + * interventions: [new HumanInTheLoop({ allowedTools: ['readTool'] })], + * }) + * + * // Default: agent pauses with stopReason 'interrupt', caller resumes with response + * const result = await agent.invoke('Delete the file') + * ``` + */ + +export { HumanInTheLoop } from './hitl.js' +export type { HumanInTheLoopConfig } from './hitl.js' diff --git a/strands-ts/src/vended-interventions/steering/__tests__/handler.test.ts b/strands-ts/src/vended-interventions/steering/__tests__/handler.test.ts new file mode 100644 index 0000000000..7b18efe1f8 --- /dev/null +++ b/strands-ts/src/vended-interventions/steering/__tests__/handler.test.ts @@ -0,0 +1,196 @@ +import { describe, expect, it, vi } from 'vitest' +import { Agent } from '../../../agent/agent.js' +import { HookRegistryImplementation } from '../../../hooks/registry.js' +import { AfterModelCallEvent, BeforeToolCallEvent } from '../../../hooks/events.js' +import { Interrupt, InterruptState } from '../../../interrupt.js' +import { confirm, guide, type Confirm, type Guide, type Proceed } from '../../../interventions/actions.js' +import { Message, TextBlock } from '../../../types/messages.js' +import type { ToolUse } from '../../../tools/types.js' +import type { LocalAgent } from '../../../types/agent.js' +import { SteeringHandler } from '../handlers/handler.js' +import type { SteeringContextData, SteeringContextProvider } from '../providers/context-provider.js' + +function getHookRegistry(agent: Agent): HookRegistryImplementation { + return (agent as unknown as { _hooksRegistry: HookRegistryImplementation })._hooksRegistry +} + +describe('SteeringHandler', () => { + const toolUse = { name: 'searchWeb', toolUseId: 'tu-1', input: { q: 'hi' } } + + function makeBeforeToolCallEvent(agent: LocalAgent): BeforeToolCallEvent { + return new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + } + + function makeAfterModelCallEvent(agent: LocalAgent): AfterModelCallEvent { + return new AfterModelCallEvent({ + agent, + model: {} as never, + invocationState: {}, + attemptCount: 0, + stopData: { + message: new Message({ role: 'assistant', content: [new TextBlock('response')] }), + stopReason: 'endTurn', + }, + }) + } + + it('routes beforeToolCall to subclass override with the event', async () => { + const seen: { agent?: LocalAgent; toolUse?: ToolUse } = {} + + class Spy extends SteeringHandler { + override readonly name = 'spy' + override async beforeToolCall(event: BeforeToolCallEvent): Promise { + seen.agent = event.agent + seen.toolUse = event.toolUse + return guide('try again') + } + } + + const agent = new Agent({ interventions: [new Spy()] }) + await agent.initialize() + const event = makeBeforeToolCallEvent(agent) + await getHookRegistry(agent).invokeCallbacks(event) + + expect(seen.agent).toBe(agent) + expect(seen.toolUse).toEqual(toolUse) + expect(event.cancel).toContain('GUIDANCE:') + expect(event.cancel).toContain('try again') + }) + + it('routes afterModelCall to subclass override with the event', async () => { + const seen: { message?: Message; stopReason?: string } = {} + + class Spy extends SteeringHandler { + override readonly name = 'spy' + override async afterModelCall(event: AfterModelCallEvent): Promise { + if (!event.stopData) return { type: 'proceed' } + seen.message = event.stopData.message + seen.stopReason = event.stopData.stopReason + return guide('be terser') + } + } + + const agent = new Agent({ interventions: [new Spy()] }) + await agent.initialize() + const event = makeAfterModelCallEvent(agent) + await getHookRegistry(agent).invokeCallbacks(event) + + expect(seen.message).toBeDefined() + expect(seen.stopReason).toBe('endTurn') + expect(event.retry).toBe(true) + }) + + it('exposes provider context to subclasses via getSteeringContext', async () => { + const fakeProvider: SteeringContextProvider = { + name: 'fake', + observeAgent() {}, + get context(): SteeringContextData { + return { type: 'fake', tokens: 42 } + }, + } + + let observedContext: SteeringContextData[] | undefined + + class ContextReader extends SteeringHandler { + override readonly name = 'context-reader' + override async beforeToolCall(): Promise { + observedContext = this.getSteeringContext() + return { type: 'proceed' } + } + } + + const agent = new Agent({ interventions: [new ContextReader({ contextProviders: [fakeProvider] })] }) + await agent.initialize() + await getHookRegistry(agent).invokeCallbacks(makeBeforeToolCallEvent(agent)) + + expect(observedContext).toEqual([{ type: 'fake', tokens: 42 }]) + }) + + it('siblings with distinct names can coexist on one agent', () => { + class A extends SteeringHandler { + override readonly name = 'steer:tool' + } + class B extends SteeringHandler { + override readonly name = 'steer:model' + } + + expect(() => new Agent({ interventions: [new A(), new B()] })).not.toThrow() + }) + + it('does not invoke afterModelCall body when stopData is missing', async () => { + const called = vi.fn() + + class Spy extends SteeringHandler { + override readonly name = 'spy' + override async afterModelCall(event: AfterModelCallEvent): Promise { + if (event.stopData) called() + return { type: 'proceed' } + } + } + + const agent = new Agent({ interventions: [new Spy()] }) + await agent.initialize() + const event = new AfterModelCallEvent({ + agent, + model: {} as never, + invocationState: {}, + attemptCount: 0, + }) + await getHookRegistry(agent).invokeCallbacks(event) + + expect(called).not.toHaveBeenCalled() + }) + + it('confirm decision flows through the interrupt system on resume (approved)', async () => { + class Approver extends SteeringHandler { + override readonly name = 'approver' + override async beforeToolCall(): Promise { + return confirm('approve searchWeb?') + } + } + + const agent = new Agent({ interventions: [new Approver()] }) + await agent.initialize() + + // Preload an approval response so event.interrupt() returns it instead of pausing + const interruptId = `hook:beforeToolCall:${toolUse.toolUseId}:approver` + const interruptState = (agent as unknown as { _interruptState: InterruptState })._interruptState + interruptState.interrupts[interruptId] = new Interrupt({ + id: interruptId, + name: 'approver', + response: 'yes' as never, + source: 'hook', + }) + + const event = makeBeforeToolCallEvent(agent) + await getHookRegistry(agent).invokeCallbacks(event) + + expect(event.cancel).toBe(false) + }) + + it('confirm decision sets cancel when human denies', async () => { + class Approver extends SteeringHandler { + override readonly name = 'approver' + override async beforeToolCall(): Promise { + return confirm('approve searchWeb?') + } + } + + const agent = new Agent({ interventions: [new Approver()] }) + await agent.initialize() + + const interruptId = `hook:beforeToolCall:${toolUse.toolUseId}:approver` + const interruptState = (agent as unknown as { _interruptState: InterruptState })._interruptState + interruptState.interrupts[interruptId] = new Interrupt({ + id: interruptId, + name: 'approver', + response: 'no' as never, + source: 'hook', + }) + + const event = makeBeforeToolCallEvent(agent) + await getHookRegistry(agent).invokeCallbacks(event) + + expect(event.cancel).toBe('CONFIRMATION_FAILED: approve searchWeb?') + }) +}) diff --git a/strands-ts/src/vended-interventions/steering/__tests__/llm.test.ts b/strands-ts/src/vended-interventions/steering/__tests__/llm.test.ts new file mode 100644 index 0000000000..52a3d9de20 --- /dev/null +++ b/strands-ts/src/vended-interventions/steering/__tests__/llm.test.ts @@ -0,0 +1,72 @@ +import { describe, expect, it, vi } from 'vitest' +import { Agent } from '../../../agent/agent.js' +import { BeforeToolCallEvent } from '../../../hooks/events.js' +import { HookRegistryImplementation } from '../../../hooks/registry.js' +import { MockMessageModel } from '../../../__fixtures__/mock-message-model.js' +import { LLMSteeringHandler } from '../handlers/llm.js' + +function getHookRegistry(agent: Agent): HookRegistryImplementation { + return (agent as unknown as { _hooksRegistry: HookRegistryImplementation })._hooksRegistry +} + +function structuredOutputModel(decision: { type: 'proceed' | 'guide' | 'confirm'; reason: string }): MockMessageModel { + return new MockMessageModel().addTurn({ + type: 'toolUseBlock', + name: 'strands_structured_output', + toolUseId: 'inner-1', + input: decision, + }) +} + +describe('LLMSteeringHandler', () => { + const toolUse = { name: 'searchWeb', toolUseId: 'tu-1', input: { q: 'hi' } } + + it("defaults to the parent agent's model when none is configured", async () => { + const model = structuredOutputModel({ type: 'proceed', reason: 'no concerning patterns' }) + const streamSpy = vi.spyOn(model, 'stream') + + const handler = new LLMSteeringHandler({ + systemPrompt: 'You are a steering agent.', + contextProviders: [], + }) + const agent = new Agent({ model, interventions: [handler] }) + await agent.initialize() + + const event = new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + await getHookRegistry(agent).invokeCallbacks(event) + + expect(streamSpy).toHaveBeenCalledTimes(1) + expect(event.cancel).toBe(false) + }) + + it('uses the configured model in preference to the agent model', async () => { + const agentModel = new MockMessageModel().addTurn({ type: 'textBlock', text: 'unused' }) + const configuredModel = structuredOutputModel({ type: 'proceed', reason: 'ok' }) + const agentStreamSpy = vi.spyOn(agentModel, 'stream') + const configuredStreamSpy = vi.spyOn(configuredModel, 'stream') + + const handler = new LLMSteeringHandler({ + systemPrompt: 'You are a steering agent.', + model: configuredModel, + contextProviders: [], + }) + const agent = new Agent({ model: agentModel, interventions: [handler] }) + await agent.initialize() + + const event = new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + await getHookRegistry(agent).invokeCallbacks(event) + + expect(configuredStreamSpy).toHaveBeenCalledTimes(1) + expect(agentStreamSpy).not.toHaveBeenCalled() + }) + + it('throws when no model is configured and the handler has no parent agent', async () => { + const handler = new LLMSteeringHandler({ + systemPrompt: 'You are a steering agent.', + contextProviders: [], + }) + + // Detached: never attached to an agent, never observed. + await expect(handler.beforeToolCall({ toolUse } as unknown as BeforeToolCallEvent)).rejects.toThrow(/no model/i) + }) +}) diff --git a/strands-ts/src/vended-interventions/steering/__tests__/tool-ledger.test.ts b/strands-ts/src/vended-interventions/steering/__tests__/tool-ledger.test.ts new file mode 100644 index 0000000000..c0227ffab7 --- /dev/null +++ b/strands-ts/src/vended-interventions/steering/__tests__/tool-ledger.test.ts @@ -0,0 +1,116 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '../../../agent/agent.js' +import { AfterToolCallEvent, BeforeToolCallEvent } from '../../../hooks/events.js' +import type { HookRegistryImplementation } from '../../../hooks/registry.js' +import { TextBlock, ToolResultBlock } from '../../../types/messages.js' +import { ToolLedgerProvider } from '../providers/tool-ledger.js' + +describe('ToolLedgerProvider', () => { + const toolUse = { name: 'searchWeb', toolUseId: 'tu-1', input: { q: 'hi' } } + + function setupAgent(provider: ToolLedgerProvider): { + agent: Agent + hookRegistry: HookRegistryImplementation + } { + const agent = new Agent() + const hookRegistry = (agent as unknown as { _hooksRegistry: HookRegistryImplementation })._hooksRegistry + provider.observeAgent(agent) + return { agent, hookRegistry } + } + + function makeBefore(agent: Agent): BeforeToolCallEvent { + return new BeforeToolCallEvent({ agent, toolUse, tool: undefined, invocationState: {} }) + } + + function makeAfter(agent: Agent, status: 'success' | 'error' = 'success', error?: Error): AfterToolCallEvent { + return new AfterToolCallEvent({ + agent, + toolUse, + tool: undefined, + result: new ToolResultBlock({ + toolUseId: toolUse.toolUseId, + status, + content: [new TextBlock('result text')], + ...(error !== undefined && { error }), + }), + invocationState: {}, + ...(error !== undefined && { error }), + }) + } + + it('records pending entry on beforeToolCall', async () => { + const provider = new ToolLedgerProvider() + const { agent, hookRegistry } = setupAgent(provider) + + expect(provider.context.type).toBe('toolLedger') + expect(provider.context.calls).toEqual([]) + + await hookRegistry.invokeCallbacks(makeBefore(agent)) + + const calls = provider.context.calls as Array> + expect(calls).toHaveLength(1) + expect(calls[0]).toMatchObject({ + id: 'tu-1', + name: 'searchWeb', + args: { q: 'hi' }, + status: 'pending', + }) + }) + + it('flips pending to success after afterToolCall', async () => { + const provider = new ToolLedgerProvider() + const { agent, hookRegistry } = setupAgent(provider) + + await hookRegistry.invokeCallbacks(makeBefore(agent)) + await hookRegistry.invokeCallbacks(makeAfter(agent, 'success')) + + const calls = provider.context.calls as Array> + expect(calls).toHaveLength(1) + expect(calls[0]).toMatchObject({ + id: 'tu-1', + name: 'searchWeb', + args: { q: 'hi' }, + status: 'success', + error: null, + endTime: expect.any(String), + }) + }) + + it('records error status and message', async () => { + const provider = new ToolLedgerProvider() + const { agent, hookRegistry } = setupAgent(provider) + + await hookRegistry.invokeCallbacks(makeBefore(agent)) + await hookRegistry.invokeCallbacks(makeAfter(agent, 'error', new Error('boom'))) + + const calls = provider.context.calls as Array> + expect(calls[0]).toMatchObject({ + id: 'tu-1', + name: 'searchWeb', + args: { q: 'hi' }, + status: 'error', + error: 'boom', + endTime: expect.any(String), + }) + }) + + it('drops oldest entries when ledger exceeds maxEntries', async () => { + const provider = new ToolLedgerProvider({ maxEntries: 2 }) + const { agent, hookRegistry } = setupAgent(provider) + + for (const id of ['a', 'b', 'c']) { + await hookRegistry.invokeCallbacks( + new BeforeToolCallEvent({ + agent, + toolUse: { name: 't', toolUseId: id, input: {} }, + tool: undefined, + invocationState: {}, + }) + ) + } + + const calls = provider.context.calls as Array> + expect(calls).toHaveLength(2) + expect(calls.map((c) => c.id)).toEqual(['b', 'c']) + }) +}) diff --git a/strands-ts/src/vended-interventions/steering/handlers/handler.ts b/strands-ts/src/vended-interventions/steering/handlers/handler.ts new file mode 100644 index 0000000000..b7bb2609dd --- /dev/null +++ b/strands-ts/src/vended-interventions/steering/handlers/handler.ts @@ -0,0 +1,93 @@ +/** + * Steering handler base class for providing contextual guidance to agents. + * + * Subclass {@link SteeringHandler} and override {@link beforeToolCall} and/or + * {@link afterModelCall}. These carry a narrowed steering contract + * (Proceed | Guide | Confirm for tool calls, Proceed | Guide for model output) + * — the wider intervention vocabulary (Deny, Transform) is excluded by the + * return type, so out-of-contract actions are caught at compile time. + * + * @example + * ```typescript + * class MySteeringHandler extends SteeringHandler { + * override readonly name = 'my-steering' + * + * override async beforeToolCall(event) { + * if (event.toolUse.name === 'dangerous_tool') { + * return guide('This tool requires extra caution.') + * } + * return proceed() + * } + * } + * + * const agent = new Agent({ tools: [...], interventions: [new MySteeringHandler()] }) + * ``` + */ + +import type { AfterModelCallEvent, BeforeToolCallEvent } from '../../../hooks/events.js' +import { InterventionHandler, type Awaitable } from '../../../interventions/handler.js' +import type { LifecycleObserver } from '../../../types/lifecycle-observer.js' +import { proceed, type Confirm, type Guide, type Proceed } from '../../../interventions/actions.js' +import type { LocalAgent } from '../../../types/agent.js' +import type { SteeringContextData, SteeringContextProvider } from '../providers/context-provider.js' + +/** + * Configuration shared by all steering handlers. + */ +export interface SteeringHandlerConfig { + /** Providers that supply evaluation context. */ + contextProviders?: SteeringContextProvider[] +} + +/** + * Base class for steering handlers that provide contextual guidance to agents. + * + * Steering handlers accept context providers that observe agent activity, and + * use the accumulated context to make guidance decisions. The handler is an + * {@link InterventionHandler} — pass it via `interventions:` on the agent. + * + * Subclasses must declare a `name` (inherited as `abstract` from + * {@link InterventionHandler}). When attaching multiple steering handlers to + * one agent, ensure their names are distinct — `InterventionRegistry` rejects + * duplicates. + */ +export abstract class SteeringHandler extends InterventionHandler implements LifecycleObserver { + abstract override readonly name: string + + private readonly _contextProviders: SteeringContextProvider[] + + constructor(config?: SteeringHandlerConfig) { + super() + this._contextProviders = config?.contextProviders ?? [] + } + + // --------------------------------------------------------------------------- + // Steering moments — narrowed return types reject out-of-contract actions. + // --------------------------------------------------------------------------- + + override beforeToolCall(_event: BeforeToolCallEvent): Awaitable { + return proceed() + } + + override afterModelCall(_event: AfterModelCallEvent): Awaitable { + return proceed() + } + + // --------------------------------------------------------------------------- + // Lifecycle observer — forward to providers so they can self-register hooks. + // --------------------------------------------------------------------------- + + async observeAgent(agent: LocalAgent): Promise { + for (const provider of this._contextProviders) { + await provider.observeAgent(agent) + } + } + + /** + * Collect context from all registered providers. Subclasses (and tests) + * may call this to inspect the accumulated provider snapshots. + */ + getSteeringContext(): SteeringContextData[] { + return this._contextProviders.map((provider) => provider.context) + } +} diff --git a/strands-ts/src/vended-interventions/steering/handlers/llm.ts b/strands-ts/src/vended-interventions/steering/handlers/llm.ts new file mode 100644 index 0000000000..03719a55ac --- /dev/null +++ b/strands-ts/src/vended-interventions/steering/handlers/llm.ts @@ -0,0 +1,245 @@ +/** + * LLM-based steering handler that uses an LLM to provide contextual guidance. + */ + +import { z } from 'zod' +import { Agent } from '../../../agent/agent.js' +import { confirm, guide, proceed, type Confirm, type Guide, type Proceed } from '../../../interventions/actions.js' +import type { Model } from '../../../models/model.js' +import type { ContentBlock, SystemPrompt } from '../../../types/messages.js' +import { CachePointBlock, TextBlock } from '../../../types/messages.js' +import type { ToolUse } from '../../../tools/types.js' +import type { BeforeToolCallEvent } from '../../../hooks/events.js' +import type { LocalAgent } from '../../../types/agent.js' +import type { SteeringContextData, SteeringContextProvider } from '../providers/context-provider.js' +import { ToolLedgerProvider } from '../providers/tool-ledger.js' +import { SteeringHandler } from './handler.js' + +// --------------------------------------------------------------------------- +// Prompt building +// --------------------------------------------------------------------------- + +/** + * Builds the evaluation prompt sent to the steering LLM. + * Return a string for simple prompts, or ContentBlock[] to use cache points. + */ +export type PromptBuilder = (context: SteeringContextData[], toolUse?: ToolUse) => string | ContentBlock[] + +/** + * Default prompt builder. Returns content blocks with a cache point + * between static instructions and dynamic context/event data. + * + * See: https://github.com/strands-agents/agent-sop + */ +function defaultPromptBuilder(context: SteeringContextData[], toolUse?: ToolUse): ContentBlock[] { + const contextStr = context.length > 0 ? JSON.stringify(context, null, 2) : 'No context available' + + const actionType = toolUse ? 'tool call' : 'action' + const actionTypeTitle = toolUse ? 'Tool Call' : 'Action' + const eventDescription = toolUse + ? `Tool: ${toolUse.name}\nArguments: ${JSON.stringify(toolUse.input, null, 2)}` + : 'General evaluation' + + const hasLedger = context.some((c) => c.type === 'toolLedger') + const ledgerExplanation = hasLedger + ? ` + +### Understanding Ledger Tool States + +If the context includes a ledger with tool_calls, the "status" field indicates: + +- **"pending"**: The tool is CURRENTLY being evaluated by you (the steering agent). +This is NOT a duplicate call — it's the tool you're deciding whether to approve. +The tool has NOT started executing yet. +- **"success"**: The tool completed successfully in a previous turn +- **"error"**: The tool failed or was cancelled in a previous turn + +**IMPORTANT**: When you see a tool with status="pending" that matches the tool you're evaluating, +that IS the current tool being evaluated. It is NOT already executing or a duplicate.` + : '' + + // Static framing (cached): role, constraints, decision criteria, ledger semantics. + const instructions = `# Steering Evaluation + +## Overview + +You are a STEERING AGENT that evaluates a ${actionType} that ANOTHER AGENT is attempting to make. +Your job is to provide contextual guidance to help the other agent navigate workflows effectively. +You act as a safety net that can intervene when patterns in the context data suggest the agent +should try a different approach or get human input. + +**YOUR ROLE:** +- Analyze context data for concerning patterns (repeated failures, inappropriate timing, etc.) +- Provide just-in-time guidance when the agent is going down an ineffective path +- Allow normal operations to proceed when context shows no issues + +**CRITICAL CONSTRAINTS:** +- Base decisions ONLY on the context data provided +- Do NOT use external knowledge about domains, URLs, or tool purposes +- Do NOT make assumptions about what tools "should" or "shouldn't" do +- Focus ONLY on patterns in the context data${ledgerExplanation} + +## Steps + +### 1. Analyze the ${actionTypeTitle} + +Review ONLY the context data. Look for patterns in the data that indicate: + +- Previous failures or successes with this tool +- Frequency of attempts +- Any relevant tracking information + +**Constraints:** +- You MUST base analysis ONLY on the provided context data +- You MUST NOT use external knowledge about tool purposes or domains +- You SHOULD identify patterns in the context data +- You MAY reference relevant context data to inform your decision + +### 2. Make Steering Decision + +**Constraints:** +- You MUST respond with exactly one of: "proceed", "guide", or "confirm" +- You MUST base the decision ONLY on context data patterns +- Your reason will be shown to the AGENT as guidance + +**Decision Options:** +- "proceed" if context data shows no concerning patterns +- "guide" if context data shows patterns requiring intervention +- "confirm" if context data shows patterns requiring human input` + + // Dynamic block (uncached): per-call context and event payload. + const dynamic = `## Context + +${contextStr} + +## Event to Evaluate + +${eventDescription}` + + return [new TextBlock(instructions), new CachePointBlock({ cacheType: 'default' }), new TextBlock(dynamic)] +} + +// --------------------------------------------------------------------------- +// LLM steering handler +// --------------------------------------------------------------------------- + +/** + * Configuration for the LLMSteeringHandler. + */ +export interface LLMSteeringHandlerConfig { + /** System prompt defining the steering guidance rules. */ + systemPrompt: SystemPrompt + + /** Model for steering evaluation. Defaults to the parent agent's model. */ + model?: Model + + /** Custom prompt builder for evaluation prompts. Defaults to defaultPromptBuilder. */ + promptBuilder?: PromptBuilder + + /** + * Context providers for populating steering context. + * Defaults to [new ToolLedgerProvider()] if undefined. Pass an empty array to disable. + */ + contextProviders?: SteeringContextProvider[] + + /** + * Identifier for this handler instance. Defaults to `'strands:llm-steering-handler'`. + * Override when attaching multiple LLM steering handlers to the same agent. + */ + name?: string +} + +/** Schema returned by the steering LLM. */ +const STEERING_DECISION = z.object({ + type: z + .enum(['proceed', 'guide', 'confirm']) + .describe( + "Steering decision: 'proceed' to continue, 'guide' to provide feedback, 'confirm' to pause for human approval" + ), + reason: z.string().describe('Clear explanation of the steering decision and any guidance provided'), +}) + +type SteeringDecision = z.infer + +/** + * Steering handler that uses an LLM to provide contextual guidance. + * + * Uses natural language prompts to evaluate tool calls and produce an + * intervention action. + * + * Only `beforeToolCall` is implemented — model-output steering is not + * delegated to the LLM. Subclass and override `afterModelCall` (which + * carries the narrowed `Proceed | Guide` return) to add LLM-driven + * evaluation of model responses. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { LLMSteeringHandler } from '@strands-agents/sdk/vended-interventions/steering' + * + * const handler = new LLMSteeringHandler({ + * systemPrompt: `You ensure emails maintain a cheerful, positive tone.`, + * }) + * + * const agent = new Agent({ tools: [sendEmail], interventions: [handler] }) + * ``` + */ +export class LLMSteeringHandler extends SteeringHandler { + override readonly name: string + + private readonly _promptBuilder: PromptBuilder + private readonly _configuredModel: Model | undefined + private _agentModel: Model | undefined + private readonly _systemPrompt: SystemPrompt + + constructor(config: LLMSteeringHandlerConfig) { + const contextProviders = + config.contextProviders === undefined ? [new ToolLedgerProvider()] : config.contextProviders + super({ contextProviders }) + + this.name = config.name ?? 'strands:llm-steering-handler' + this._promptBuilder = config.promptBuilder ?? defaultPromptBuilder + this._configuredModel = config.model + this._systemPrompt = config.systemPrompt + } + + override async observeAgent(agent: LocalAgent): Promise { + this._agentModel = agent.model + await super.observeAgent(agent) + } + + override async beforeToolCall(event: BeforeToolCallEvent): Promise { + const context = this.getSteeringContext() + const prompt = this._promptBuilder(context, event.toolUse) + const decision = await this._invoke(prompt) + + switch (decision.type) { + case 'proceed': + return proceed({ reason: decision.reason }) + case 'guide': + return guide(decision.reason) + case 'confirm': + return confirm(decision.reason, { reason: decision.reason }) + } + } + + // Constructs a fresh inner agent per call so the handler has no shared + // mutable state between invocations — this keeps it safe to attach to + // multiple parent agents (whose tool calls may evaluate concurrently). + private async _invoke(prompt: string | ContentBlock[]): Promise { + const model = this._configuredModel ?? this._agentModel + if (!model) { + throw new Error( + 'LLMSteeringHandler has no model — pass `model` in config, or attach the handler to an agent before invoking it.' + ) + } + const inner = new Agent({ + model, + systemPrompt: this._systemPrompt, + structuredOutputSchema: STEERING_DECISION, + printer: false, + }) + const result = await inner.invoke(prompt) + return STEERING_DECISION.parse(result.structuredOutput) + } +} diff --git a/strands-ts/src/vended-interventions/steering/index.ts b/strands-ts/src/vended-interventions/steering/index.ts new file mode 100644 index 0000000000..f6f1686e5f --- /dev/null +++ b/strands-ts/src/vended-interventions/steering/index.ts @@ -0,0 +1,36 @@ +/** + * Steering system for Strands agents. + * + * Provides contextual guidance for agents through modular prompting. + * Instead of front-loading all instructions, steering handlers provide + * just-in-time feedback based on context data from registered providers. + * + * Steering handlers are {@link InterventionHandler}s — register them on the + * agent via the `interventions:` option, not `plugins:`. + * + * Core components: + * - SteeringHandler: base class for guidance logic + * - SteeringContextProvider: interface for context data providers + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { LLMSteeringHandler } from '@strands-agents/sdk/vended-interventions/steering' + * + * const handler = new LLMSteeringHandler({ + * systemPrompt: '...', + * model: new BedrockModel(), + * }) + * const agent = new Agent({ tools: [...], interventions: [handler] }) + * ``` + */ + +// Core +export type { SteeringContextData, SteeringContextProvider } from './providers/context-provider.js' +export { SteeringHandler, type SteeringHandlerConfig } from './handlers/handler.js' + +// Context providers +export { ToolLedgerProvider, type ToolLedgerProviderConfig } from './providers/tool-ledger.js' + +// Handler implementations +export { LLMSteeringHandler, type LLMSteeringHandlerConfig, type PromptBuilder } from './handlers/llm.js' diff --git a/strands-ts/src/vended-interventions/steering/providers/context-provider.ts b/strands-ts/src/vended-interventions/steering/providers/context-provider.ts new file mode 100644 index 0000000000..81a7e439a1 --- /dev/null +++ b/strands-ts/src/vended-interventions/steering/providers/context-provider.ts @@ -0,0 +1,59 @@ +/** + * Steering context provider interface. + * + * Providers track agent activity and supply context data to steering handlers + * for evaluation decisions. + */ + +import type { LocalAgent } from '../../../types/agent.js' +import type { LifecycleObserver } from '../../../types/lifecycle-observer.js' +import type { JSONValue } from '../../../types/json.js' + +/** + * Context data returned by a SteeringContextProvider. + * The type field identifies which provider produced the data. + */ +export interface SteeringContextData { + /** Discriminator identifying the context provider. */ + readonly type: string + /** Additional context fields. */ + [key: string]: JSONValue +} + +/** + * A passive observer that accumulates data from agent lifecycle events. + * + * Providers self-register hook callbacks via {@link LifecycleObserver.observeAgent}, + * which the owning {@link SteeringHandler} invokes once at registration time. + * + * Providers expose accumulated state through the `context` getter, which the + * handler reads when making steering decisions. + * + * @example + * ```typescript + * class CostTracker implements SteeringContextProvider { + * readonly name = 'costTracker' + * private _toolCalls = 0 + * + * observeAgent(agent: LocalAgent): void { + * agent.addHook(AfterToolCallEvent, () => { + * this._toolCalls += 1 + * }) + * } + * + * get context(): SteeringContextData { + * return { type: 'costTracker', toolCalls: this._toolCalls } + * } + * } + * ``` + */ +export interface SteeringContextProvider extends LifecycleObserver { + /** Identifier for this provider instance. */ + readonly name: string + + /** Subscribe to hooks on the owning agent. Required for providers. */ + observeAgent(agent: LocalAgent): void | Promise + + /** Return the current context snapshot for steering evaluation. */ + get context(): SteeringContextData +} diff --git a/strands-ts/src/vended-interventions/steering/providers/tool-ledger.ts b/strands-ts/src/vended-interventions/steering/providers/tool-ledger.ts new file mode 100644 index 0000000000..cb239b2a66 --- /dev/null +++ b/strands-ts/src/vended-interventions/steering/providers/tool-ledger.ts @@ -0,0 +1,117 @@ +/** + * Ledger context provider for comprehensive agent activity tracking. + * + * Tracks tool call history with inputs, outputs, timing, and success/failure status. + * This audit trail enables steering handlers to make informed guidance decisions + * based on agent behavior patterns and history. + */ + +import { AfterToolCallEvent, BeforeToolCallEvent } from '../../../hooks/events.js' +import type { LocalAgent } from '../../../types/agent.js' +import type { ToolResultStatus } from '../../../tools/types.js' +import type { JSONValue } from '../../../types/json.js' +import type { SteeringContextData, SteeringContextProvider } from './context-provider.js' + +/** + * A single entry in the tool call ledger. + */ +interface LedgerToolCall { + /** Tool input arguments. */ + args: JSONValue + /** When the tool finished executing. */ + endTime?: string + /** Error message if the tool failed. */ + error?: string | null + /** Unique tool use identifier. */ + id: string + /** Tool name. */ + name: string + /** Tool execution result. */ + result?: JSONValue + /** When the tool call was initiated. */ + startTime: string + /** Current execution state: pending while in-flight, then the underlying {@link ToolResultStatus}. */ + status: 'pending' | ToolResultStatus +} + +/** + * Configuration for {@link ToolLedgerProvider}. + */ +export interface ToolLedgerProviderConfig { + /** Maximum number of tool calls to retain. Older entries are dropped. Defaults to 100. */ + maxEntries?: number + /** Identifier for this provider instance. Defaults to `'strands:steering:toolLedger'`. */ + name?: string +} + +/** + * Context provider that tracks tool call history within a single invocation. + * + * Records every tool invocation with inputs, execution time, and success/failure status. + * The ledger is available to steering handlers for pattern detection + * (e.g., repeated failures, excessive retries). + * + * When the ledger exceeds maxEntries, the oldest entries are dropped. + * + * @example + * ```typescript + * const handler = new LLMSteeringHandler({ + * systemPrompt: '...', + * contextProviders: [new ToolLedgerProvider()], + * }) + * ``` + */ +export class ToolLedgerProvider implements SteeringContextProvider { + readonly name: string + private readonly _maxEntries: number = 100 + private readonly _toolCalls: LedgerToolCall[] = [] + + constructor(config?: ToolLedgerProviderConfig) { + this.name = config?.name ?? 'strands:steering:toolLedger' + if (config?.maxEntries !== undefined) { + this._maxEntries = config.maxEntries + } + } + + observeAgent(agent: LocalAgent): void { + agent.addHook(BeforeToolCallEvent, (event) => this._onBeforeToolCall(event)) + agent.addHook(AfterToolCallEvent, (event) => this._onAfterToolCall(event)) + } + + private _onBeforeToolCall(event: BeforeToolCallEvent): void { + this._toolCalls.push({ + startTime: new Date().toISOString(), + id: event.toolUse.toolUseId, + name: event.toolUse.name, + args: event.toolUse.input, + status: 'pending', + }) + if (this._toolCalls.length > this._maxEntries) { + this._toolCalls.splice(0, this._toolCalls.length - this._maxEntries) + } + } + + private _onAfterToolCall(event: AfterToolCallEvent): void { + const toolUseId = event.toolUse.toolUseId + for (let i = this._toolCalls.length - 1; i >= 0; i--) { + const call = this._toolCalls[i] + if (call?.id === toolUseId) { + call.endTime = new Date().toISOString() + call.status = event.result.status + call.result = event.result.content.map((block) => block.toJSON()) as JSONValue + call.error = event.error ? event.error.message : null + break + } + } + } + + /** + * Return the current ledger snapshot. + */ + get context(): SteeringContextData { + return { + type: 'toolLedger', + calls: this._toolCalls as unknown as JSONValue, + } + } +} diff --git a/strands-ts/src/vended-plugins/context-offloader/__tests__/plugin.test.ts b/strands-ts/src/vended-plugins/context-offloader/__tests__/plugin.test.ts new file mode 100644 index 0000000000..e86764547d --- /dev/null +++ b/strands-ts/src/vended-plugins/context-offloader/__tests__/plugin.test.ts @@ -0,0 +1,611 @@ +import { describe, it, expect, vi } from 'vitest' +import { ContextOffloader } from '../plugin.js' +import { InMemoryStorage } from '../storage.js' +import { AfterToolCallEvent } from '../../../hooks/events.js' +import { TextBlock, JsonBlock, ToolResultBlock } from '../../../types/messages.js' +import { ImageBlock, VideoBlock, DocumentBlock } from '../../../types/media.js' +import { createMockAgent, invokeTrackedHook } from '../../../__fixtures__/agent-helpers.js' +import { MockMessageModel } from '../../../__fixtures__/mock-message-model.js' + +const mockModel = new MockMessageModel() + +function makeMockAgent() { + return createMockAgent({ extra: { model: mockModel } as never }) +} + +function makeEvent( + content: InstanceType< + typeof TextBlock | typeof JsonBlock | typeof ImageBlock | typeof VideoBlock | typeof DocumentBlock + >[], + overrides?: { status?: 'success' | 'error'; toolName?: string } +) { + const agent = makeMockAgent() + const result = new ToolResultBlock({ + toolUseId: 'tool-123', + status: overrides?.status ?? 'success', + content, + }) + return new AfterToolCallEvent({ + agent, + toolUse: { name: overrides?.toolName ?? 'some_tool', toolUseId: 'tool-123', input: {} }, + tool: undefined, + result, + invocationState: {}, + }) +} + +describe('ContextOffloader', () => { + describe('constructor validation', () => { + it('throws if maxResultTokens is not positive', () => { + expect(() => new ContextOffloader({ storage: new InMemoryStorage(), maxResultTokens: 0 })).toThrow( + 'maxResultTokens must be positive' + ) + }) + + it('throws if previewTokens is negative', () => { + expect(() => new ContextOffloader({ storage: new InMemoryStorage(), previewTokens: -1 })).toThrow( + 'previewTokens must be non-negative' + ) + }) + + it('throws if previewTokens >= maxResultTokens', () => { + expect( + () => new ContextOffloader({ storage: new InMemoryStorage(), maxResultTokens: 100, previewTokens: 100 }) + ).toThrow('previewTokens must be less than maxResultTokens') + }) + }) + + describe('plugin interface', () => { + it('has correct name', () => { + const plugin = new ContextOffloader({ storage: new InMemoryStorage() }) + expect(plugin.name).toBe('strands:context-offloader') + }) + + it('registers AfterToolCallEvent hook', () => { + const plugin = new ContextOffloader({ storage: new InMemoryStorage() }) + const agent = createMockAgent() + plugin.initAgent(agent) + expect(agent.trackedHooks).toHaveLength(1) + expect(agent.trackedHooks[0]!.eventType).toBe(AfterToolCallEvent) + }) + + it('returns retrieval tool by default', () => { + const plugin = new ContextOffloader({ storage: new InMemoryStorage() }) + const tools = plugin.getTools() + expect(tools).toHaveLength(1) + expect(tools[0]!.name).toBe('retrieve_offloaded_content') + }) + + it('returns empty tools when includeRetrievalTool is false', () => { + const plugin = new ContextOffloader({ storage: new InMemoryStorage(), includeRetrievalTool: false }) + expect(plugin.getTools()).toHaveLength(0) + }) + }) + + describe('hook behavior', () => { + it('does not offload results below threshold', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ storage, maxResultTokens: 2500 }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const event = makeEvent([new TextBlock('short text')]) + await invokeTrackedHook(agent, event) + + expect(event.result.content).toHaveLength(1) + expect(event.result.content[0]).toBeInstanceOf(TextBlock) + expect((event.result.content[0] as TextBlock).text).toBe('short text') + }) + + it('does not offload error results', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ storage, maxResultTokens: 10, previewTokens: 5 }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const event = makeEvent([new TextBlock('x'.repeat(1000))], { status: 'error' }) + await invokeTrackedHook(agent, event) + + expect((event.result.content[0] as TextBlock).text).toBe('x'.repeat(1000)) + }) + + it('does not offload retrieval tool results', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ + storage, + maxResultTokens: 10, + previewTokens: 5, + includeRetrievalTool: true, + }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const event = makeEvent([new TextBlock('x'.repeat(1000))], { toolName: 'retrieve_offloaded_content' }) + await invokeTrackedHook(agent, event) + + expect((event.result.content[0] as TextBlock).text).toBe('x'.repeat(1000)) + }) + + it('offloads large text results', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ storage, maxResultTokens: 100, previewTokens: 10 }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const largeText = 'a'.repeat(2000) + const event = makeEvent([new TextBlock(largeText)]) + await invokeTrackedHook(agent, event) + + expect(event.result.content).toHaveLength(1) + const preview = (event.result.content[0] as TextBlock).text + expect(preview).toContain('[Offloaded:') + expect(preview).toContain('Tool result was offloaded') + expect(preview).toContain('[Stored references:]') + expect(preview).not.toContain(largeText) + }) + + it('offloads large JSON results', async () => { + const storage = new InMemoryStorage() + // JSON uses chars/2 heuristic, so 1000 chars of JSON ≈ 500 tokens + const plugin = new ContextOffloader({ storage, maxResultTokens: 10, previewTokens: 5 }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const largeJson = { data: 'x'.repeat(1000) } + const event = makeEvent([new JsonBlock({ json: largeJson })]) + await invokeTrackedHook(agent, event) + + const preview = (event.result.content[0] as TextBlock).text + expect(preview).toContain('[Offloaded:') + expect(preview).toContain('json,') + }) + + it('offloads image blocks with placeholder', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ storage, maxResultTokens: 10, previewTokens: 5 }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const imgBytes = new Uint8Array(10000) + const event = makeEvent([ + new TextBlock('x'.repeat(1000)), + new ImageBlock({ format: 'png', source: { bytes: imgBytes } }), + ]) + await invokeTrackedHook(agent, event) + + const imageBlock = event.result.content.find((b) => b instanceof TextBlock && b.text.includes('[image:')) + expect(imageBlock).toBeDefined() + expect((imageBlock as TextBlock).text).toContain('[image: png,') + expect((imageBlock as TextBlock).text).toContain('ref:') + }) + + it('offloads document blocks with placeholder', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ storage, maxResultTokens: 10, previewTokens: 5 }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const docBytes = new Uint8Array(10000) + const event = makeEvent([ + new TextBlock('x'.repeat(1000)), + new DocumentBlock({ format: 'pdf', name: 'report.pdf', source: { bytes: docBytes } }), + ]) + await invokeTrackedHook(agent, event) + + const docBlock = event.result.content.find((b) => b instanceof TextBlock && b.text.includes('[document:')) + expect(docBlock).toBeDefined() + expect((docBlock as TextBlock).text).toContain('[document: pdf, report.pdf,') + expect((docBlock as TextBlock).text).toContain('ref:') + }) + + it('preserves original result on storage failure', async () => { + const failingStorage: InMemoryStorage = new InMemoryStorage() + vi.spyOn(failingStorage, 'store').mockImplementation(() => { + throw new Error('storage down') + }) + + const plugin = new ContextOffloader({ storage: failingStorage, maxResultTokens: 10, previewTokens: 5 }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const event = makeEvent([new TextBlock('x'.repeat(1000))]) + const originalResult = event.result + await invokeTrackedHook(agent, event) + + expect(event.result).toBe(originalResult) + }) + + it('includes retrieval tool guidance when enabled', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ + storage, + maxResultTokens: 10, + previewTokens: 5, + includeRetrievalTool: true, + }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const event = makeEvent([new TextBlock('x'.repeat(1000))]) + await invokeTrackedHook(agent, event) + + const preview = (event.result.content[0] as TextBlock).text + expect(preview).toContain('retrieve_offloaded_content') + expect(preview).toContain('pattern') + expect(preview).toContain('line_range') + }) + + it('respects custom previewTokens', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ storage, maxResultTokens: 10, previewTokens: 2 }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const event = makeEvent([new TextBlock('a'.repeat(1000))]) + await invokeTrackedHook(agent, event) + + const preview = (event.result.content[0] as TextBlock).text + const previewSection = preview.split('[Stored references:]')[0] + // previewTokens=2 → 2*4=8 chars of 'a' in preview + expect(previewSection).toContain('a'.repeat(8)) + expect(previewSection).not.toContain('a'.repeat(100)) + }) + + it('stores and retrieves content round-trip', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ + storage, + maxResultTokens: 10, + previewTokens: 5, + includeRetrievalTool: true, + }) + const agent = createMockAgent() + plugin.initAgent(agent) + + const event = makeEvent([new TextBlock('hello world '.repeat(100))]) + await invokeTrackedHook(agent, event) + + const preview = (event.result.content[0] as TextBlock).text + const refMatch = preview.match(/mem_\d+_tool-123_0/) + expect(refMatch).not.toBeNull() + + const retrieved = await storage.retrieve(refMatch![0]) + expect(new TextDecoder().decode(retrieved.content)).toBe('hello world '.repeat(100)) + }) + }) + + describe('retrieval tool', () => { + it('retrieves text content as string', async () => { + const storage = new InMemoryStorage() + const ref = await storage.store('k1', new TextEncoder().encode('hello'), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const tools = plugin.getTools() + const retrievalTool = tools[0]! + const result = await (retrievalTool as unknown as { invoke(input: unknown): Promise }).invoke({ + reference: ref, + }) + expect(result).toBe('hello') + }) + + it('retrieves JSON content as parsed object', async () => { + const storage = new InMemoryStorage() + const ref = await storage.store('k1', new TextEncoder().encode('{"foo":"bar"}'), 'application/json') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const tools = plugin.getTools() + const retrievalTool = tools[0]! + const result = await (retrievalTool as unknown as { invoke(input: unknown): Promise }).invoke({ + reference: ref, + }) + expect(result).toEqual({ foo: 'bar' }) + }) + + it('retrieves image content as ImageBlock', async () => { + const storage = new InMemoryStorage() + const imgBytes = new Uint8Array([137, 80, 78, 71]) + const ref = await storage.store('k1', imgBytes, 'image/png') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const tools = plugin.getTools() + const retrievalTool = tools[0]! + const result = await (retrievalTool as unknown as { invoke(input: unknown): Promise }).invoke({ + reference: ref, + }) + expect(result).toBeInstanceOf(ImageBlock) + expect((result as ImageBlock).format).toBe('png') + }) + + it('retrieves video content as VideoBlock', async () => { + const storage = new InMemoryStorage() + const vidBytes = new Uint8Array([0x00, 0x00, 0x00, 0x1c]) + const ref = await storage.store('k1', vidBytes, 'video/mp4') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const tools = plugin.getTools() + const retrievalTool = tools[0]! + const result = await (retrievalTool as unknown as { invoke(input: unknown): Promise }).invoke({ + reference: ref, + }) + expect(result).toBeInstanceOf(VideoBlock) + expect((result as VideoBlock).format).toBe('mp4') + }) + + it('retrieves document content as DocumentBlock', async () => { + const storage = new InMemoryStorage() + const docBytes = new Uint8Array([0x25, 0x50, 0x44, 0x46]) + const ref = await storage.store('k1', docBytes, 'application/pdf') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const tools = plugin.getTools() + const retrievalTool = tools[0]! + const result = await (retrievalTool as unknown as { invoke(input: unknown): Promise }).invoke({ + reference: ref, + }) + expect(result).toBeInstanceOf(DocumentBlock) + expect((result as DocumentBlock).format).toBe('pdf') + }) + + it('returns error string for missing reference', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const tools = plugin.getTools() + const retrievalTool = tools[0]! + const result = await (retrievalTool as unknown as { invoke(input: unknown): Promise }).invoke({ + reference: 'nonexistent', + }) + expect(result).toContain('Error: reference not found') + }) + }) + + describe('search via retrieval tool', () => { + function getRetrievalTool(plugin: ContextOffloader) { + const tools = plugin.getTools() + return tools[0]! as unknown as { invoke(input: unknown): Promise } + } + + it('finds matching lines with context', async () => { + const storage = new InMemoryStorage() + const content = Array.from({ length: 20 }, (_, i) => `line ${i + 1}`).join('\n') + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + pattern: 'line 10', + context_lines: 2, + })) as string + + expect(result).toContain('1 match for /line 10/') + expect(result).toContain('> 10| line 10') + expect(result).toContain(' 8| line 8') + expect(result).toContain(' 12| line 12') + }) + + it('returns line range without pattern', async () => { + const storage = new InMemoryStorage() + const content = Array.from({ length: 50 }, (_, i) => `line ${i + 1}`).join('\n') + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + line_range: { start: 5, end: 10 }, + })) as string + + expect(result).toContain('[Lines 5-10 of 50]') + expect(result).toContain(' 5| line 5') + expect(result).toContain(' 10| line 10') + expect(result).not.toContain('line 4') + expect(result).not.toContain('line 11') + }) + + it('searches within line range when both provided', async () => { + const storage = new InMemoryStorage() + const content = Array.from({ length: 30 }, (_, i) => `item ${i + 1}`).join('\n') + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + pattern: 'item 1', + line_range: { start: 10, end: 20 }, + context_lines: 0, + })) as string + + expect(result).toContain('in lines 10-20') + expect(result).toContain('> 10| item 10') + expect(result).toContain('> 11| item 11') + expect(result).not.toContain('> 1|') + }) + + it('respects custom context_lines', async () => { + const storage = new InMemoryStorage() + const content = Array.from({ length: 20 }, (_, i) => `line ${i + 1}`).join('\n') + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + pattern: 'line 10', + context_lines: 0, + })) as string + + expect(result).toContain('> 10| line 10') + expect(result).not.toContain('line 9') + expect(result).not.toContain('line 11') + }) + + it('returns error for binary content', async () => { + const storage = new InMemoryStorage() + const ref = await storage.store('k1', new Uint8Array([137, 80, 78, 71]), 'image/png') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + pattern: 'test', + })) as string + + expect(result).toContain('Error: cannot search binary content (image/png)') + }) + + it('falls back to literal match on invalid regex', async () => { + const storage = new InMemoryStorage() + const content = 'foo (bar\nbaz\nfoo (bar again' + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + pattern: 'foo (bar', + context_lines: 0, + })) as string + + expect(result).toContain('2 matches') + expect(result).toContain('> 1| foo (bar') + expect(result).toContain('> 3| foo (bar again') + }) + + it('returns error for missing reference', async () => { + const storage = new InMemoryStorage() + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: 'nonexistent', + pattern: 'test', + })) as string + + expect(result).toContain('Error: reference not found') + }) + + it('searches JSON content', async () => { + const storage = new InMemoryStorage() + const json = JSON.stringify({ name: 'test', items: [1, 2, 3] }, null, 2) + const ref = await storage.store('k1', new TextEncoder().encode(json), 'application/json') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + pattern: 'items', + context_lines: 1, + })) as string + + expect(result).toContain('1 match for /items/') + expect(result).toContain('items') + }) + + it('reports no matches', async () => { + const storage = new InMemoryStorage() + const content = 'hello\nworld\n' + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + pattern: 'nonexistent', + })) as string + + expect(result).toContain("No matches found for pattern 'nonexistent'") + }) + + it('truncates output when too many matches', async () => { + const storage = new InMemoryStorage() + const content = Array.from({ length: 500 }, (_, i) => `match line ${i + 1}`).join('\n') + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ + storage, + maxResultTokens: 50, + previewTokens: 10, + includeRetrievalTool: true, + }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + pattern: 'match', + context_lines: 0, + })) as string + + expect(result).toContain('output truncated, narrow your search') + expect(result.length).toBeLessThan(content.length) + }) + + it('merges overlapping context into single group', async () => { + const storage = new InMemoryStorage() + const content = Array.from({ length: 10 }, (_, i) => `line ${i + 1}`).join('\n') + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + pattern: 'line [45]', + context_lines: 2, + })) as string + + expect(result).toContain('2 matches') + // Lines 4 and 5 are adjacent — with context_lines=2 they should merge into one group + expect(result).not.toContain('---') + }) + + it('returns error when line_range.start > totalLines', async () => { + const storage = new InMemoryStorage() + const content = 'line 1\nline 2\nline 3' + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + line_range: { start: 100, end: 200 }, + })) as string + + expect(result).toContain('beyond content length (3 lines)') + }) + + it('clamps line_range.end to actual content length', async () => { + const storage = new InMemoryStorage() + const content = 'line 1\nline 2\nline 3' + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + line_range: { start: 2, end: 100 }, + })) as string + + expect(result).toContain('[Lines 2-3 of 3]') + expect(result).toContain('line 2') + expect(result).toContain('line 3') + }) + + it('returns first N lines when only context_lines is provided', async () => { + const storage = new InMemoryStorage() + const content = Array.from({ length: 20 }, (_, i) => `line ${i + 1}`).join('\n') + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + context_lines: 10, + })) as string + + expect(result).toContain('[Lines 1-10 of 20]') + expect(result).toContain('line 1') + expect(result).toContain('line 10') + expect(result).not.toContain('line 11') + }) + + it('returns first line when context_lines is 0', async () => { + const storage = new InMemoryStorage() + const content = Array.from({ length: 10 }, (_, i) => `line ${i + 1}`).join('\n') + const ref = await storage.store('k1', new TextEncoder().encode(content), 'text/plain') + + const plugin = new ContextOffloader({ storage, includeRetrievalTool: true }) + const result = (await getRetrievalTool(plugin).invoke({ + reference: ref, + context_lines: 0, + })) as string + + expect(result).toContain('[Lines 1-1 of 10]') + expect(result).toContain('line 1') + expect(result).not.toContain('line 2') + }) + }) +}) diff --git a/strands-ts/src/vended-plugins/context-offloader/__tests__/search.test.ts b/strands-ts/src/vended-plugins/context-offloader/__tests__/search.test.ts new file mode 100644 index 0000000000..c4a81d7fae --- /dev/null +++ b/strands-ts/src/vended-plugins/context-offloader/__tests__/search.test.ts @@ -0,0 +1,187 @@ +import { describe, it, expect } from 'vitest' +import { searchContent, isSearchableContent } from '../search.js' + +describe('isSearchableContent', () => { + it('returns true for text/* types', () => { + expect(isSearchableContent('text/plain')).toBe(true) + expect(isSearchableContent('text/html')).toBe(true) + }) + + it('returns true for application/json', () => { + expect(isSearchableContent('application/json')).toBe(true) + }) + + it('returns false for binary types', () => { + expect(isSearchableContent('image/png')).toBe(false) + expect(isSearchableContent('video/mp4')).toBe(false) + expect(isSearchableContent('application/pdf')).toBe(false) + }) +}) + +describe('searchContent', () => { + const maxChars = 10_000 + + describe('empty content', () => { + it('returns empty message for empty string', () => { + expect(searchContent('', { context_lines: 5, pattern: 'x' }, maxChars)).toBe('Content is empty (0 lines).') + }) + + it('returns empty message for single empty line', () => { + expect(searchContent('\n', { context_lines: 5, line_range: { start: 1, end: 1 } }, maxChars)).not.toContain( + 'Content is empty' + ) + }) + }) + + describe('line_range validation', () => { + const text = 'line 1\nline 2\nline 3\nline 4\nline 5' + + it('returns error when start > end', () => { + const result = searchContent(text, { context_lines: 5, line_range: { start: 5, end: 2 } }, maxChars) + expect(result).toContain('must be <= line_range.end') + }) + + it('returns error when start > total lines', () => { + const result = searchContent(text, { context_lines: 5, line_range: { start: 100, end: 200 } }, maxChars) + expect(result).toContain('beyond content length (5 lines)') + }) + + it('clamps end to total lines', () => { + const result = searchContent(text, { context_lines: 5, line_range: { start: 3, end: 999 } }, maxChars) + expect(result).toContain('[Lines 3-5 of 5]') + expect(result).toContain('line 3') + expect(result).toContain('line 5') + }) + }) + + describe('pattern search', () => { + const text = Array.from({ length: 20 }, (_, i) => `line ${i + 1}`).join('\n') + + it('finds a single match with context', () => { + const result = searchContent(text, { pattern: 'line 10', context_lines: 2 }, maxChars) + expect(result).toContain('[1 match for /line 10/]') + expect(result).toContain('> 10| line 10') + expect(result).toContain(' 8| line 8') + expect(result).toContain(' 12| line 12') + expect(result).not.toContain('line 7') + }) + + it('finds multiple matches', () => { + const result = searchContent(text, { pattern: 'line [12]0', context_lines: 0 }, maxChars) + expect(result).toContain('2 matches') + expect(result).toContain('> 10| line 10') + expect(result).toContain('> 20| line 20') + }) + + it('returns no-match message when pattern not found', () => { + const result = searchContent(text, { pattern: 'nonexistent', context_lines: 5 }, maxChars) + expect(result).toContain("No matches found for pattern 'nonexistent'") + expect(result).toContain('searched 20 lines') + }) + + it('uses context_lines: 0 for no context', () => { + const result = searchContent(text, { pattern: 'line 5', context_lines: 0 }, maxChars) + expect(result).toContain('> 5| line 5') + expect(result).not.toContain('line 4') + expect(result).not.toContain('line 6') + }) + + it('merges overlapping context into one group', () => { + const result = searchContent(text, { pattern: 'line [67]', context_lines: 2 }, maxChars) + expect(result).toContain('2 matches') + expect(result).not.toContain('---') + }) + + it('separates non-overlapping groups with ---', () => { + const result = searchContent(text, { pattern: 'line (1|20)', context_lines: 0 }, maxChars) + expect(result).toContain('---') + }) + + it('falls back to literal match on invalid regex', () => { + const text = 'foo (bar\nbaz\nfoo (bar again' + const result = searchContent(text, { pattern: 'foo (bar', context_lines: 0 }, maxChars) + expect(result).toContain('2 matches') + expect(result).toContain('> 1| foo (bar') + expect(result).toContain('> 3| foo (bar again') + }) + + it('sanitizes pattern in header', () => { + const text = 'test line\nanother line' + const result = searchContent(text, { pattern: 'test]\nline', context_lines: 0 }, maxChars) + // The header should not contain raw ] or newlines + const header = result.split('\n')[0]! + expect(header).not.toContain(']/') + expect(header).not.toContain('\n') + }) + }) + + describe('pattern search with line_range', () => { + const text = Array.from({ length: 30 }, (_, i) => `item ${i + 1}`).join('\n') + + it('searches only within the specified range', () => { + const result = searchContent( + text, + { pattern: 'item 1', line_range: { start: 10, end: 20 }, context_lines: 0 }, + maxChars + ) + expect(result).toContain('in lines 10-20') + expect(result).toContain('> 10| item 10') + expect(result).toContain('> 11| item 11') + expect(result).not.toContain('> 1|') + }) + + it('reports no matches within range', () => { + const result = searchContent( + text, + { pattern: 'item 5', line_range: { start: 10, end: 20 }, context_lines: 0 }, + maxChars + ) + expect(result).toContain('No matches found') + expect(result).toContain('in lines 10-20') + }) + }) + + describe('line range (no pattern)', () => { + const text = Array.from({ length: 50 }, (_, i) => `line ${i + 1}`).join('\n') + + it('returns specified range with header', () => { + const result = searchContent(text, { line_range: { start: 5, end: 10 }, context_lines: 5 }, maxChars) + expect(result).toContain('[Lines 5-10 of 50]') + expect(result).toContain(' 5| line 5') + expect(result).toContain(' 10| line 10') + }) + + it('does not show lines outside range', () => { + const result = searchContent(text, { line_range: { start: 5, end: 10 }, context_lines: 5 }, maxChars) + expect(result).not.toContain('line 4') + expect(result).not.toContain('line 11') + }) + + it('does not include --- separators for contiguous lines', () => { + const result = searchContent(text, { line_range: { start: 1, end: 10 }, context_lines: 5 }, maxChars) + expect(result).not.toContain('---') + }) + }) + + describe('truncation', () => { + it('truncates pattern results when output exceeds maxChars', () => { + const text = Array.from({ length: 500 }, (_, i) => `match line ${i + 1}`).join('\n') + const result = searchContent(text, { pattern: 'match', context_lines: 0 }, 200) + expect(result).toContain('output truncated, narrow your search') + expect(result.length).toBeLessThanOrEqual(250) // 200 + truncation message + }) + + it('truncates line range results when output exceeds maxChars', () => { + const text = Array.from({ length: 500 }, (_, i) => `line ${i + 1}`).join('\n') + const result = searchContent(text, { line_range: { start: 1, end: 500 }, context_lines: 5 }, 200) + expect(result).toContain('output truncated, narrow your range') + expect(result.length).toBeLessThanOrEqual(250) + }) + + it('does not truncate when output fits within maxChars', () => { + const text = 'short\ncontent' + const result = searchContent(text, { line_range: { start: 1, end: 2 }, context_lines: 5 }, maxChars) + expect(result).not.toContain('truncated') + }) + }) +}) diff --git a/strands-ts/src/vended-plugins/context-offloader/__tests__/storage.test.node.ts b/strands-ts/src/vended-plugins/context-offloader/__tests__/storage.test.node.ts new file mode 100644 index 0000000000..3ca67f7662 --- /dev/null +++ b/strands-ts/src/vended-plugins/context-offloader/__tests__/storage.test.node.ts @@ -0,0 +1,97 @@ +import { describe, it, expect, beforeEach, afterEach } from 'vitest' +import { FileStorage } from '../storage.js' +import * as fs from 'node:fs/promises' +import * as path from 'node:path' +import * as os from 'node:os' + +describe('FileStorage', () => { + let tmpDir: string + + beforeEach(async () => { + tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), 'context-offloader-test-')) + }) + + afterEach(async () => { + await fs.rm(tmpDir, { recursive: true, force: true }) + }) + + it('stores and retrieves text content', async () => { + const storage = new FileStorage(tmpDir) + const content = new TextEncoder().encode('hello world') + const ref = await storage.store('key1', content, 'text/plain') + + const result = await storage.retrieve(ref) + expect(new TextDecoder().decode(result.content)).toBe('hello world') + expect(result.contentType).toBe('text/plain') + }) + + it('stores and retrieves binary content', async () => { + const storage = new FileStorage(tmpDir) + const content = new Uint8Array([1, 2, 3, 4, 5]) + const ref = await storage.store('key1', content, 'image/png') + + const result = await storage.retrieve(ref) + expect(result.content).toEqual(content) + expect(result.contentType).toBe('image/png') + }) + + it('returns file path as reference preserving configured directory', async () => { + const storage = new FileStorage(tmpDir) + const content = new TextEncoder().encode('test') + const ref = await storage.store('k1', content, 'text/plain') + + expect(ref.startsWith(tmpDir)).toBe(true) + expect(ref).toMatch(/\.txt$/) + }) + + it('uses correct file extensions', async () => { + const storage = new FileStorage(tmpDir) + const content = new TextEncoder().encode('test') + + const txtRef = await storage.store('k1', content, 'text/plain') + expect(txtRef).toMatch(/\.txt$/) + + const jsonRef = await storage.store('k2', content, 'application/json') + expect(jsonRef).toMatch(/\.json$/) + + const pngRef = await storage.store('k3', content, 'image/png') + expect(pngRef).toMatch(/\.png$/) + }) + + it('throws on missing reference', async () => { + const storage = new FileStorage(tmpDir) + await expect(storage.retrieve(path.join(tmpDir, 'nonexistent.txt'))).rejects.toThrow('Reference not found') + }) + + it('sanitizes keys for safe filenames', async () => { + const storage = new FileStorage(tmpDir) + const content = new TextEncoder().encode('test') + const ref = await storage.store('../../../etc/passwd', content, 'text/plain') + expect(ref).not.toContain('..') + }) + + it('prevents path traversal on retrieve', async () => { + const storage = new FileStorage(tmpDir) + await expect(storage.retrieve('../../etc/passwd')).rejects.toThrow('Reference not found') + }) + + it('creates artifact directory if it does not exist', async () => { + const nestedDir = path.join(tmpDir, 'nested', 'dir') + const storage = new FileStorage(nestedDir) + const content = new TextEncoder().encode('test') + await storage.store('key1', content, 'text/plain') + + const stat = await fs.stat(nestedDir) + expect(stat.isDirectory()).toBe(true) + }) + + it('persists metadata across instances', async () => { + const storage1 = new FileStorage(tmpDir) + const content = new TextEncoder().encode('test') + const ref = await storage1.store('key1', content, 'application/json') + + const storage2 = new FileStorage(tmpDir) + const result = await storage2.retrieve(ref) + expect(result.contentType).toBe('application/json') + }) +}) diff --git a/strands-ts/src/vended-plugins/context-offloader/__tests__/storage.test.ts b/strands-ts/src/vended-plugins/context-offloader/__tests__/storage.test.ts new file mode 100644 index 0000000000..e442055828 --- /dev/null +++ b/strands-ts/src/vended-plugins/context-offloader/__tests__/storage.test.ts @@ -0,0 +1,190 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { InMemoryStorage, S3Storage } from '../storage.js' + +describe('InMemoryStorage', () => { + it('stores and retrieves text content', async () => { + const storage = new InMemoryStorage() + const content = new TextEncoder().encode('hello world') + const ref = await storage.store('key1', content, 'text/plain') + + const result = await storage.retrieve(ref) + expect(new TextDecoder().decode(result.content)).toBe('hello world') + expect(result.contentType).toBe('text/plain') + }) + + it('stores and retrieves binary content', async () => { + const storage = new InMemoryStorage() + const content = new Uint8Array([1, 2, 3, 4, 5]) + const ref = await storage.store('key1', content, 'image/png') + + const result = await storage.retrieve(ref) + expect(result.content).toEqual(content) + expect(result.contentType).toBe('image/png') + }) + + it('generates unique references', async () => { + const storage = new InMemoryStorage() + const content = new TextEncoder().encode('test') + const ref1 = await storage.store('key1', content) + const ref2 = await storage.store('key2', content) + expect(ref1).not.toBe(ref2) + }) + + it('uses mem_ prefix in references', async () => { + const storage = new InMemoryStorage() + const ref = await storage.store('mykey', new TextEncoder().encode('test')) + expect(ref).toMatch(/^mem_\d+_mykey$/) + }) + + it('throws on missing reference', async () => { + const storage = new InMemoryStorage() + await expect(storage.retrieve('nonexistent')).rejects.toThrow('Reference not found: nonexistent') + }) + + it('clears all stored content', async () => { + const storage = new InMemoryStorage() + const ref = await storage.store('key1', new TextEncoder().encode('test')) + storage.clear() + await expect(storage.retrieve(ref)).rejects.toThrow('Reference not found') + }) + + it('defaults content type to text/plain', async () => { + const storage = new InMemoryStorage() + const ref = await storage.store('key1', new TextEncoder().encode('test')) + const result = await storage.retrieve(ref) + expect(result.contentType).toBe('text/plain') + }) +}) + +describe('S3Storage', () => { + let mockSend: ReturnType + let mockS3Client: { send: ReturnType } + + beforeEach(() => { + mockSend = vi.fn() + mockS3Client = { send: mockSend } + }) + + describe('store', () => { + it('returns s3:// URI as reference', async () => { + mockSend.mockResolvedValue({}) + const storage = new S3Storage('my-bucket', { s3Client: mockS3Client as never }) + + const ref = await storage.store('key1', new TextEncoder().encode('test'), 'text/plain') + + expect(ref).toMatch(/^s3:\/\/my-bucket\//) + expect(ref).toContain('key1') + }) + + it('includes prefix in s3 key', async () => { + mockSend.mockResolvedValue({}) + const storage = new S3Storage('my-bucket', { prefix: 'artifacts', s3Client: mockS3Client as never }) + + const ref = await storage.store('key1', new TextEncoder().encode('test')) + + expect(ref).toMatch(/^s3:\/\/my-bucket\/artifacts\//) + }) + + it('normalizes trailing slashes on prefix', async () => { + mockSend.mockResolvedValue({}) + const storage = new S3Storage('b', { prefix: 'p///', s3Client: mockS3Client as never }) + + const ref = await storage.store('k', new TextEncoder().encode('x')) + + expect(ref).toMatch(/^s3:\/\/b\/p\//) + // Check no double slashes in the path portion (after s3://) + const pathPortion = ref.replace('s3://', '') + expect(pathPortion).not.toContain('//') + }) + + it('sends correct PutObject params', async () => { + mockSend.mockResolvedValue({}) + const storage = new S3Storage('my-bucket', { s3Client: mockS3Client as never }) + const content = new TextEncoder().encode('hello') + + await storage.store('key1', content, 'application/json') + + expect(mockSend).toHaveBeenCalledOnce() + const command = mockSend.mock.calls[0]![0] + expect(command.input.Bucket).toBe('my-bucket') + expect(command.input.Body).toBe(content) + expect(command.input.ContentType).toBe('application/json') + }) + + it('sanitizes keys', async () => { + mockSend.mockResolvedValue({}) + const storage = new S3Storage('b', { s3Client: mockS3Client as never }) + + const ref = await storage.store('../../etc/passwd', new TextEncoder().encode('x')) + + expect(ref).not.toContain('..') + expect(ref).not.toContain('etc/passwd') + }) + }) + + describe('retrieve', () => { + it('retrieves content by s3:// URI', async () => { + mockSend.mockResolvedValueOnce({}).mockResolvedValueOnce({ + Body: { transformToByteArray: () => Promise.resolve(new Uint8Array([1, 2, 3])) }, + ContentType: 'image/png', + }) + + const storage = new S3Storage('my-bucket', { s3Client: mockS3Client as never }) + const ref = await storage.store('key1', new Uint8Array([1, 2, 3]), 'image/png') + const result = await storage.retrieve(ref) + + expect(result.content).toEqual(new Uint8Array([1, 2, 3])) + expect(result.contentType).toBe('image/png') + }) + + it('retrieves content by raw key', async () => { + mockSend.mockResolvedValue({ + Body: { transformToByteArray: () => Promise.resolve(new TextEncoder().encode('hello')) }, + ContentType: 'text/plain', + }) + + const storage = new S3Storage('b', { s3Client: mockS3Client as never }) + const result = await storage.retrieve('some/raw/key') + + expect(new TextDecoder().decode(result.content)).toBe('hello') + const command = mockSend.mock.calls[0]![0] + expect(command.input.Key).toBe('some/raw/key') + }) + + it('throws on bucket mismatch', async () => { + const storage = new S3Storage('my-bucket', { s3Client: mockS3Client as never }) + + await expect(storage.retrieve('s3://wrong-bucket/key')).rejects.toThrow('bucket mismatch') + }) + + it('throws on NoSuchKey error', async () => { + const noSuchKey = new Error('not found') + noSuchKey.name = 'NoSuchKey' + mockSend.mockRejectedValue(noSuchKey) + + const storage = new S3Storage('b', { s3Client: mockS3Client as never }) + + await expect(storage.retrieve('missing-key')).rejects.toThrow('Reference not found') + }) + + it('defaults contentType to application/octet-stream when missing', async () => { + mockSend.mockResolvedValue({ + Body: { transformToByteArray: () => Promise.resolve(new Uint8Array([0])) }, + }) + + const storage = new S3Storage('b', { s3Client: mockS3Client as never }) + const result = await storage.retrieve('key') + + expect(result.contentType).toBe('application/octet-stream') + }) + + it('rethrows non-NoSuchKey errors', async () => { + const networkError = new Error('network timeout') + mockSend.mockRejectedValue(networkError) + + const storage = new S3Storage('b', { s3Client: mockS3Client as never }) + + await expect(storage.retrieve('key')).rejects.toThrow('network timeout') + }) + }) +}) diff --git a/strands-ts/src/vended-plugins/context-offloader/index.ts b/strands-ts/src/vended-plugins/context-offloader/index.ts new file mode 100644 index 0000000000..21c4fc9023 --- /dev/null +++ b/strands-ts/src/vended-plugins/context-offloader/index.ts @@ -0,0 +1,23 @@ +/** + * Context offloading plugin for Strands Agents. + * + * This module provides the ContextOffloader plugin and Storage backends for + * automatically offloading oversized tool results to external storage, replacing + * them with truncated previews and actionable storage references. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { ContextOffloader, InMemoryStorage } from '@strands-agents/sdk/vended-plugins/context-offloader' + * + * const agent = new Agent({ + * model, + * plugins: [new ContextOffloader({ storage: new InMemoryStorage() })], + * }) + * ``` + */ + +export { ContextOffloader } from './plugin.js' +export type { ContextOffloaderConfig } from './plugin.js' +export type { Storage } from './storage.js' +export { InMemoryStorage, FileStorage, S3Storage } from './storage.js' diff --git a/strands-ts/src/vended-plugins/context-offloader/plugin.ts b/strands-ts/src/vended-plugins/context-offloader/plugin.ts new file mode 100644 index 0000000000..1c19af01bb --- /dev/null +++ b/strands-ts/src/vended-plugins/context-offloader/plugin.ts @@ -0,0 +1,343 @@ +import type { Plugin } from '../../plugins/plugin.js' +import type { Tool } from '../../tools/tool.js' +import type { LocalAgent } from '../../types/agent.js' +import { AfterToolCallEvent } from '../../hooks/events.js' +import { TextBlock, JsonBlock, ToolResultBlock, Message } from '../../types/messages.js' +import type { ToolResultContent } from '../../types/messages.js' +import { ImageBlock, VideoBlock, DocumentBlock } from '../../types/media.js' +import type { ImageFormat, VideoFormat, DocumentFormat } from '../../types/media.js' +import { tool } from '../../tools/tool-factory.js' +import { z } from 'zod' +import { logger } from '../../logging/logger.js' +import type { JSONValue } from '../../types/json.js' +import type { Storage } from './storage.js' +import { isSearchableContent, searchContent } from './search.js' + +const CHARS_PER_TOKEN = 4 +const DEFAULT_MAX_RESULT_TOKENS = 2_500 +const DEFAULT_PREVIEW_TOKENS = 1_000 +const RETRIEVAL_TOOL_NAME = 'retrieve_offloaded_content' + +const retrievalInputSchema = z.object({ + reference: z.string().describe('The reference string from the offload placeholder (e.g. "mem_1_tool-123_0").'), + pattern: z + .string() + .optional() + .describe('Regex or keyword to grep for. Returns only matching lines with context — not the full content.'), + line_range: z + .object({ + start: z.number().int().min(1).describe('First line to return (1-indexed).'), + end: z.number().int().min(1).describe('Last line to return (1-indexed).'), + }) + .optional() + .describe('Return only this span of lines. Combine with pattern to search within the range.'), + context_lines: z + .number() + .int() + .min(0) + .optional() + .describe( + 'Lines before AND after each match (like grep -C). Default: 5. Without pattern/line_range, returns first N lines.' + ), +}) + +function slicePreview(text: string, previewTokens: number): string { + const maxChars = previewTokens * CHARS_PER_TOKEN + if (text.length <= maxChars) return text + return text.slice(0, maxChars) +} + +function getBytes(block: ToolResultContent): Uint8Array | undefined { + if (block instanceof ImageBlock && block.source.type === 'imageSourceBytes') { + return block.source.bytes + } + if (block instanceof VideoBlock && block.source.type === 'videoSourceBytes') { + return block.source.bytes + } + if (block instanceof DocumentBlock) { + if (block.source.type === 'documentSourceBytes') return block.source.bytes + if (block.source.type === 'documentSourceText') return new TextEncoder().encode(block.source.text) + } + return undefined +} + +function decodeStoredContent(content: Uint8Array, contentType: string, reference: string): JSONValue { + if (contentType.startsWith('text/')) { + return new TextDecoder().decode(content) + } + if (contentType === 'application/json') { + const text = new TextDecoder().decode(content) + try { + return JSON.parse(text) as JSONValue + } catch { + return text + } + } + // Return native content blocks for binary types so the agent sees the actual content. + // FunctionTool._wrapInToolResult passes ImageBlock/VideoBlock/DocumentBlock through as-is + // at runtime, even though the callback type signature only accepts JSONValue. + if (contentType.startsWith('image/')) { + const format = contentType.split('/').pop()! + return new ImageBlock({ + format: format as ImageFormat, + source: { bytes: content }, + }) as unknown as JSONValue + } + if (contentType.startsWith('video/')) { + const format = contentType.split('/').pop()! + return new VideoBlock({ + format: format as VideoFormat, + source: { bytes: content }, + }) as unknown as JSONValue + } + if (contentType.startsWith('application/')) { + const format = contentType.split('/').pop()! + return new DocumentBlock({ + format: format as DocumentFormat, + name: reference, + source: { bytes: content }, + }) as unknown as JSONValue + } + return new TextDecoder('utf-8', { fatal: false }).decode(content) +} + +/** Configuration for the {@link ContextOffloader} plugin. */ +export interface ContextOffloaderConfig { + /** Storage backend for persisting offloaded content. */ + storage: Storage + /** Token threshold above which tool results are offloaded. Defaults to 2,500. */ + maxResultTokens?: number + /** Number of tokens to keep as an inline preview. Defaults to 1,000. */ + previewTokens?: number + /** Whether to register the `retrieve_offloaded_content` tool. Defaults to true. */ + includeRetrievalTool?: boolean +} + +/** + * Plugin that offloads oversized tool results to reduce context consumption. + * + * When a tool result exceeds the configured token threshold, this plugin stores + * each content block to a storage backend and replaces the in-context result with + * a truncated text preview plus per-block storage references. + * + * @example + * ```typescript + * import { ContextOffloader, InMemoryStorage } from '@strands-agents/sdk/vended-plugins/context-offloader' + * + * const agent = new Agent({ + * model, + * plugins: [new ContextOffloader({ storage: new InMemoryStorage() })], + * }) + * ``` + */ +export class ContextOffloader implements Plugin { + readonly name = 'strands:context-offloader' + + private readonly _storage: Storage + private readonly _maxResultTokens: number + private readonly _previewTokens: number + private readonly _includeRetrievalTool: boolean + private _retrievalTool: Tool | undefined + + constructor(config: ContextOffloaderConfig) { + const maxResultTokens = config.maxResultTokens ?? DEFAULT_MAX_RESULT_TOKENS + const previewTokens = config.previewTokens ?? DEFAULT_PREVIEW_TOKENS + + if (maxResultTokens <= 0) throw new Error('maxResultTokens must be positive') + if (previewTokens < 0) throw new Error('previewTokens must be non-negative') + if (previewTokens >= maxResultTokens) throw new Error('previewTokens must be less than maxResultTokens') + + this._storage = config.storage + this._maxResultTokens = maxResultTokens + this._previewTokens = previewTokens + this._includeRetrievalTool = config.includeRetrievalTool ?? true + } + + initAgent(agent: LocalAgent): void { + agent.addHook(AfterToolCallEvent, (event) => this._handleToolResult(event)) + } + + getTools(): Tool[] { + if (!this._includeRetrievalTool) return [] + if (!this._retrievalTool) this._retrievalTool = this._createRetrievalTool() + return [this._retrievalTool] + } + + private _createRetrievalTool(): Tool { + const storage = this._storage + const maxChars = this._maxResultTokens * CHARS_PER_TOKEN + + return tool({ + name: RETRIEVAL_TOOL_NAME, + description: + 'When a tool result was too large to keep in context, it was stored externally and replaced with a preview and a reference. ' + + 'Use this tool with that reference to access the stored content.\n\n' + + 'Returns:\n' + + ' - With pattern: matching lines with line numbers and surrounding context\n' + + ' - With line_range: the specified span of lines with line numbers\n' + + ' - Without pattern/line_range: the full original content (use sparingly — re-injects all tokens)\n\n' + + 'Constraints:\n' + + ' - pattern/line_range/context_lines only work on text content. For binary content, omit them.\n' + + ' - Line numbers in results are 1-indexed and can be used in follow-up line_range calls.\n\n' + + 'Examples:\n' + + ' { reference: "ref_1", pattern: "error" } → lines containing "error" with 5 lines context\n' + + ' { reference: "ref_1", pattern: "error|warning", context_lines: 3 } → regex, 3 lines context\n' + + ' { reference: "ref_1", line_range: { start: 10, end: 25 } } → lines 10-25\n' + + ' { reference: "ref_1", pattern: "TODO", line_range: { start: 1, end: 50 } } → search within range', + inputSchema: retrievalInputSchema, + callback: async (input) => { + try { + const result = await storage.retrieve(input.reference) + + if (!input.pattern && !input.line_range && input.context_lines === undefined) { + return decodeStoredContent(result.content, result.contentType, input.reference) + } + + if (!isSearchableContent(result.contentType)) { + return `Error: cannot search binary content (${result.contentType}). Omit pattern/line_range/context_lines to retrieve the full content.` + } + + const text = new TextDecoder().decode(result.content) + const contextLines = input.context_lines ?? 5 + const lineRange = + input.line_range ?? (!input.pattern ? { start: 1, end: Math.max(1, contextLines) } : undefined) + + return searchContent( + text, + { pattern: input.pattern, line_range: lineRange, context_lines: contextLines }, + maxChars + ) + } catch { + return `Error: reference not found: ${input.reference}` + } + }, + }) + } + + private async _storeBlock( + block: ToolResultContent, + key: string + ): Promise<{ ref: string; contentType: string; description: string }> { + if (block instanceof TextBlock && block.text) { + const ref = await this._storage.store(key, new TextEncoder().encode(block.text), 'text/plain') + return { ref, contentType: 'text/plain', description: `text, ${block.text.length.toLocaleString()} chars` } + } + if (block instanceof JsonBlock) { + const jsonStr = JSON.stringify(block.json, null, 2) + const jsonBytes = new TextEncoder().encode(jsonStr) + const ref = await this._storage.store(key, jsonBytes, 'application/json') + return { ref, contentType: 'application/json', description: `json, ${jsonBytes.length.toLocaleString()} bytes` } + } + if (block instanceof ImageBlock || block instanceof VideoBlock || block instanceof DocumentBlock) { + const bytes = getBytes(block) + const contentType = + block instanceof ImageBlock + ? `image/${block.format}` + : block instanceof VideoBlock + ? `video/${block.format}` + : `application/${block.format}` + const label = block instanceof DocumentBlock ? block.name : contentType + if (bytes) { + const ref = await this._storage.store(key, bytes, contentType) + return { ref, contentType, description: `${label}, ${bytes.length.toLocaleString()} bytes` } + } + return { ref: '', contentType, description: `${label}, 0 bytes` } + } + logger.warn('unsupported content block type encountered during offloading, skipping') + return { ref: '', contentType: 'unknown', description: 'unknown block type' } + } + + private _buildPreviewText( + content: ToolResultContent[], + references: Array<{ ref: string; description: string }>, + tokenCount: number, + fullText: string + ): string { + const preview = fullText ? slicePreview(fullText, this._previewTokens) : '' + const refLines = references + .filter((r) => r.ref) + .map((r) => ` ${r.ref} (${r.description})`) + .join('\n') + + let guidance = + 'Tool result was offloaded to external storage due to size.\n' + + 'Use the preview below if it answers your question.\n' + if (this._includeRetrievalTool) { + guidance += + 'If you need more detail, use retrieve_offloaded_content with a reference and:\n' + + ' - pattern: regex or keyword to find matching lines with context\n' + + ' - line_range: { start, end } to read a specific span of lines\n' + + 'Retrieve full content (omit pattern/line_range) as a last resort.' + } else { + guidance += 'If you need more detail, use your available tools to access specific data.' + } + + return ( + `[Offloaded: ${content.length} blocks, ~${tokenCount.toLocaleString()} tokens]\n` + + `${guidance}\n\n` + + `${preview}\n\n` + + `[Stored references:]\n${refLines}` + ) + } + + private async _handleToolResult(event: AfterToolCallEvent): Promise { + if (event.result.status === 'error') return + + // Skip results from the retrieval tool to prevent circular offloading + if (this._includeRetrievalTool && event.toolUse.name === RETRIEVAL_TOOL_NAME) return + + const content = event.result.content + const toolUseId = event.result.toolUseId + + const tokenCount = await event.agent.model.countTokens([new Message({ role: 'user', content: [event.result] })]) + + if (tokenCount <= this._maxResultTokens) return + + // Extract text preview from text/JSON blocks + const textParts: string[] = [] + for (const block of content) { + if (block instanceof TextBlock && block.text) textParts.push(block.text) + else if (block instanceof JsonBlock) textParts.push(JSON.stringify(block.json, null, 2)) + } + const fullText = textParts.join('\n') + + // Store each content block to the storage backend + let references: Array<{ ref: string; contentType: string; description: string }> + try { + references = await Promise.all(content.map((block, i) => this._storeBlock(block, `${toolUseId}_${i}`))) + } catch (err) { + logger.warn(`tool_use_id=<${toolUseId}> | failed to offload tool result, keeping original`, err) + return + } + + logger.debug( + `tool_use_id=<${toolUseId}>, blocks=<${references.length}>, tokens=<${tokenCount}> | tool result offloaded` + ) + + // Build replacement content: preview text + media placeholders + const newContent: ToolResultContent[] = [ + new TextBlock(this._buildPreviewText(content, references, tokenCount, fullText)), + ] + for (let i = 0; i < content.length; i++) { + const block = content[i]! + const ref = references[i]?.ref ?? '' + if (block instanceof TextBlock || block instanceof JsonBlock) continue + + const bytes = getBytes(block) + const size = bytes ? bytes.length : 0 + let label: string | undefined + if (block instanceof ImageBlock) label = `image: ${block.format}` + else if (block instanceof VideoBlock) label = `video: ${block.format}` + else if (block instanceof DocumentBlock) label = `document: ${block.format}, ${block.name}` + if (label) { + newContent.push(new TextBlock(`[${label}, ${size} bytes${ref ? ` | ref: ${ref}` : ''}]`)) + } + } + + event.result = new ToolResultBlock({ + toolUseId: event.result.toolUseId, + status: event.result.status, + content: newContent, + }) + } +} diff --git a/strands-ts/src/vended-plugins/context-offloader/search.ts b/strands-ts/src/vended-plugins/context-offloader/search.ts new file mode 100644 index 0000000000..88cfcd4351 --- /dev/null +++ b/strands-ts/src/vended-plugins/context-offloader/search.ts @@ -0,0 +1,153 @@ +/** + * Search and formatting utilities for offloaded content. + * + * Provides grep-like pattern matching and line-range random access over stored + * text content, with output capped to a character budget. + */ + +/** Cuts output at the last newline before {@link maxChars} and appends a truncation message. */ +function truncate(output: string, maxChars: number, message: string): string { + if (output.length <= maxChars) return output + + const cut = output.lastIndexOf('\n', maxChars) + const sliceEnd = cut > 0 ? cut : maxChars + + return output.slice(0, sliceEnd) + `\n\n[${message}]` +} + +/** Formats line indices with line numbers, `>` prefixes for matches, and `---` separators for gaps. */ +function formatLines(lines: string[], indices: number[], matchedSet: Set): string { + if (indices.length === 0) return '' + const padWidth = String(indices[indices.length - 1]! + 1).length + const output: string[] = [] + for (let i = 0; i < indices.length; i++) { + const idx = indices[i]! + if (i > 0 && idx > indices[i - 1]! + 1) output.push('---') + const lineNum = String(idx + 1).padStart(padWidth) + const prefix = matchedSet.has(idx) ? '>' : ' ' + output.push(`${prefix} ${lineNum}| ${lines[idx]}`) + } + return output.join('\n') +} + +// Mitigates ReDoS from overly long patterns. Short pathological patterns (e.g. `(a+)+$`) +// are still possible but unlikely since the agent provides the pattern, not end users. +const MAX_PATTERN_LENGTH = 200 + +/** Finds lines matching a pattern, expands with context, and formats with truncation. */ +function searchByPattern( + lines: string[], + pattern: string, + scopeStart: number, + scopeEnd: number, + contextLines: number, + maxChars: number, + scopeLabel: string +): string { + let regex: RegExp + const safeInput = + pattern.length > MAX_PATTERN_LENGTH + ? pattern.slice(0, MAX_PATTERN_LENGTH).replace(/[.*+?^${}()|[\]\\]/g, '\\$&') + : pattern + try { + regex = new RegExp(safeInput) + } catch { + regex = new RegExp(safeInput.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')) + } + + const matchedSet = new Set() + for (let i = scopeStart; i <= scopeEnd; i++) { + if (regex.test(lines[i]!)) matchedSet.add(i) + } + + if (matchedSet.size === 0) { + return `No matches found for pattern '${pattern}'${scopeLabel} (searched ${scopeEnd - scopeStart + 1} lines).` + } + + const visible = new Set() + for (const idx of matchedSet) { + for (let i = Math.max(scopeStart, idx - contextLines); i <= Math.min(scopeEnd, idx + contextLines); i++) { + visible.add(i) + } + } + + const safePattern = pattern.replace(/[\n\r/\]]/g, ' ').slice(0, 50) + const header = `[${matchedSet.size} match${matchedSet.size > 1 ? 'es' : ''} for /${safePattern}/${scopeLabel}]` + const body = formatLines( + lines, + [...visible].sort((a, b) => a - b), + matchedSet + ) + return truncate(`${header}\n\n${body}`, maxChars, 'output truncated, narrow your search') +} + +/** Formats a contiguous range of lines with truncation. */ +function searchByLineRange(lines: string[], start: number, end: number, totalLines: number, maxChars: number): string { + const indices = Array.from({ length: end - start + 1 }, (_, i) => start + i) + const header = `[Lines ${start + 1}-${end + 1} of ${totalLines}]` + const body = formatLines(lines, indices, new Set()) + return truncate(`${header}\n\n${body}`, maxChars, 'output truncated, narrow your range') +} + +const TEXT_APPLICATION_TYPES = new Set([ + 'application/json', + 'application/xml', + 'application/javascript', + 'application/typescript', + 'application/yaml', + 'application/x-yaml', + 'application/toml', + 'application/sql', + 'application/graphql', + 'application/xhtml+xml', +]) + +/** Returns whether the given MIME content type can be searched as text. */ +export function isSearchableContent(contentType: string): boolean { + return contentType.startsWith('text/') || TEXT_APPLICATION_TYPES.has(contentType) +} + +/** + * Search offloaded text content by pattern or line range. + * + * @param text - The full text content to search + * @param input - Search parameters (pattern, line_range, context_lines) + * @param maxChars - Maximum output size in characters; results are truncated beyond this + * @returns Formatted search results with line numbers, or an error/empty message + */ +export function searchContent( + text: string, + input: { + pattern?: string | undefined + line_range?: { start: number; end: number } | undefined + context_lines: number + }, + maxChars: number +): string { + const lines = text.split('\n') + const totalLines = lines.length + + if (totalLines === 0 || (totalLines === 1 && lines[0] === '')) { + return 'Content is empty (0 lines).' + } + + let scopeStart = 0 + let scopeEnd = totalLines - 1 + if (input.line_range) { + if (input.line_range.start > input.line_range.end) { + return `Error: line_range.start (${input.line_range.start}) must be <= line_range.end (${input.line_range.end}).` + } + if (input.line_range.start > totalLines) { + return `Error: line_range.start (${input.line_range.start}) is beyond content length (${totalLines} lines).` + } + scopeStart = input.line_range.start - 1 + scopeEnd = Math.min(input.line_range.end - 1, totalLines - 1) + } + + if (input.pattern) { + const scopeLabel = input.line_range ? ` in lines ${input.line_range.start}-${scopeEnd + 1}` : '' + return searchByPattern(lines, input.pattern, scopeStart, scopeEnd, input.context_lines, maxChars, scopeLabel) + } + + return searchByLineRange(lines, scopeStart, scopeEnd, totalLines, maxChars) +} diff --git a/strands-ts/src/vended-plugins/context-offloader/storage.ts b/strands-ts/src/vended-plugins/context-offloader/storage.ts new file mode 100644 index 0000000000..c6ec1dd7c7 --- /dev/null +++ b/strands-ts/src/vended-plugins/context-offloader/storage.ts @@ -0,0 +1,262 @@ +/** + * Storage backends for offloaded tool result content. + * + * This module defines the {@link Storage} interface and provides three built-in + * implementations: {@link InMemoryStorage}, {@link FileStorage}, and {@link S3Storage}. + * Each content block from a tool result is stored individually with its content type preserved. + */ + +/** + * Backend for storing and retrieving offloaded content blocks. + * + * Implement this interface to create custom storage backends (e.g., Redis, DynamoDB). + * The SDK ships three built-in implementations: {@link InMemoryStorage}, + * {@link FileStorage}, and {@link S3Storage}. + */ +export interface Storage { + /** + * Store content and return a reference identifier. + * + * @param key - Unique key for this content block + * @param content - Raw content bytes to store + * @param contentType - MIME type of the content (e.g., "text/plain", "image/png") + * @returns Reference string for later retrieval + */ + store(key: string, content: Uint8Array, contentType?: string): Promise + + /** + * Retrieve previously stored content by reference. + * + * @param reference - Reference returned by a previous {@link store} call + * @returns Content bytes and content type + * @throws Error if the reference is not found + */ + retrieve(reference: string): Promise<{ content: Uint8Array; contentType: string }> +} + +function sanitizeId(rawId: string): string { + return rawId + .replace(/\.\./g, '_') + .replace(/[/\\]/g, '_') + .replace(/[^\w\-.]/g, '_') +} + +/** + * In-memory storage backend. + * + * Useful for testing and serverless environments where disk access is not available. + * Content accumulates for the lifetime of this instance; call {@link clear} to free memory. + */ +export class InMemoryStorage implements Storage { + private _store = new Map() + private _counter = 0 + + /** {@inheritdoc} */ + async store(key: string, content: Uint8Array, contentType: string = 'text/plain'): Promise { + this._counter++ + const reference = `mem_${this._counter}_${key}` + this._store.set(reference, { content, contentType }) + return reference + } + + /** {@inheritdoc} */ + async retrieve(reference: string): Promise<{ content: Uint8Array; contentType: string }> { + const entry = this._store.get(reference) + if (!entry) { + throw new Error(`Reference not found: ${reference}`) + } + return entry + } + + /** Remove all stored content. */ + clear(): void { + this._store.clear() + } +} + +/** + * File-based storage backend. + * + * Stores offloaded content as files on disk. File extensions are derived from the + * content type. A `.metadata.json` sidecar file tracks content types across restarts. + * References are file paths preserving the configured artifact directory form. + * + * @param artifactDir - Directory path where artifact files will be stored + */ +export class FileStorage implements Storage { + private static readonly METADATA_FILE = '.metadata.json' + private readonly _artifactDir: string + private _counter = 0 + private _contentTypes: Record = {} + private _metadataLoaded = false + private _metadataWriteChain: Promise = Promise.resolve() + + constructor(artifactDir: string = './artifacts') { + this._artifactDir = artifactDir + } + + private static _extensionFor(contentType: string): string { + if (contentType === 'text/plain') return '.txt' + return `.${contentType.split('/').pop()}` + } + + private async _ensureDir(): Promise { + const fs = await import('node:fs/promises') + await fs.mkdir(this._artifactDir, { recursive: true }) + if (!this._metadataLoaded) { + this._contentTypes = await this._loadMetadata(fs) + this._metadataLoaded = true + } + return fs + } + + private async _loadMetadata(fs: typeof import('node:fs/promises')): Promise> { + const path = await import('node:path') + const metadataPath = path.join(this._artifactDir, FileStorage.METADATA_FILE) + try { + const raw = await fs.readFile(metadataPath, 'utf-8') + return JSON.parse(raw) as Record + } catch { + return {} + } + } + + private async _saveMetadata(fs: typeof import('node:fs/promises')): Promise { + const path = await import('node:path') + const metadataPath = path.join(this._artifactDir, FileStorage.METADATA_FILE) + await fs.writeFile(metadataPath, JSON.stringify(this._contentTypes), 'utf-8') + } + + /** {@inheritdoc} */ + async store(key: string, content: Uint8Array, contentType: string = 'text/plain'): Promise { + const fs = await this._ensureDir() + const path = await import('node:path') + + const sanitizedKey = sanitizeId(key) + const timestampMs = Date.now() + this._counter++ + const ext = FileStorage._extensionFor(contentType) + const filename = `${timestampMs}_${this._counter}_${sanitizedKey}${ext}` + + this._contentTypes[filename] = contentType + this._metadataWriteChain = this._metadataWriteChain.then(() => this._saveMetadata(fs)) + await this._metadataWriteChain + + const filePath = path.join(this._artifactDir, filename) + await fs.writeFile(filePath, content) + + return filePath + } + + /** {@inheritdoc} */ + async retrieve(reference: string): Promise<{ content: Uint8Array; contentType: string }> { + const fs = await this._ensureDir() + const path = await import('node:path') + + const filePath = path.resolve(this._artifactDir, reference) + const resolvedDir = path.resolve(this._artifactDir) + if (!filePath.startsWith(resolvedDir)) { + throw new Error(`Reference not found: ${reference}`) + } + + const filename = path.basename(filePath) + + try { + const content = await fs.readFile(filePath) + const contentType = this._contentTypes[filename] ?? 'application/octet-stream' + return { content: new Uint8Array(content), contentType } + } catch { + throw new Error(`Reference not found: ${reference}`) + } + } +} + +/** + * S3-based storage backend. + * + * Stores offloaded content as S3 objects. Content type is preserved as S3 object metadata. + * References are `s3://` URIs for direct access via AWS CLI or SDK. + * + * @param bucket - S3 bucket name + * @param options - Optional configuration (prefix, region, pre-configured S3Client) + */ +export class S3Storage implements Storage { + private readonly _bucket: string + private readonly _prefix: string + private _client: import('@aws-sdk/client-s3').S3Client | undefined + private readonly _region: string + private _counter = 0 + + constructor( + bucket: string, + options?: { prefix?: string; region?: string; s3Client?: import('@aws-sdk/client-s3').S3Client } + ) { + this._bucket = bucket + this._prefix = options?.prefix ? options.prefix.replace(/\/+$/, '') + '/' : '' + this._client = options?.s3Client + this._region = options?.region ?? 'us-east-1' + } + + private async _getClient(): Promise { + if (this._client) return this._client + const { S3Client } = await import('@aws-sdk/client-s3') + this._client = new S3Client({ region: this._region }) + return this._client + } + + /** {@inheritdoc} */ + async store(key: string, content: Uint8Array, contentType: string = 'text/plain'): Promise { + const client = await this._getClient() + const { PutObjectCommand } = await import('@aws-sdk/client-s3') + + const sanitizedKey = sanitizeId(key) + const timestampMs = Date.now() + this._counter++ + const s3Key = `${this._prefix}${timestampMs}_${this._counter}_${sanitizedKey}` + + await client.send( + new PutObjectCommand({ + Bucket: this._bucket, + Key: s3Key, + Body: content, + ContentType: contentType, + }) + ) + + return `s3://${this._bucket}/${s3Key}` + } + + /** {@inheritdoc} */ + async retrieve(reference: string): Promise<{ content: Uint8Array; contentType: string }> { + const client = await this._getClient() + const { GetObjectCommand } = await import('@aws-sdk/client-s3') + + // Accept both s3:// URIs and raw keys + let s3Key = reference + const uriMatch = reference.match(/^s3:\/\/([^/]+)\/(.+)$/) + if (uriMatch?.[1] && uriMatch[2]) { + if (uriMatch[1] !== this._bucket) { + throw new Error(`Reference not found: ${reference} (bucket mismatch)`) + } + s3Key = uriMatch[2] + } + + try { + const response = await client.send( + new GetObjectCommand({ + Bucket: this._bucket, + Key: s3Key, + }) + ) + const body = await response.Body?.transformToByteArray() + if (!body) throw new Error(`Reference not found: ${reference}`) + const contentType = response.ContentType ?? 'application/octet-stream' + return { content: new Uint8Array(body), contentType } + } catch (error: unknown) { + if (error instanceof Error && error.name === 'NoSuchKey') { + throw new Error(`Reference not found: ${reference}`) + } + throw error + } + } +} diff --git a/strands-ts/src/vended-plugins/index.ts b/strands-ts/src/vended-plugins/index.ts new file mode 100644 index 0000000000..7686216176 --- /dev/null +++ b/strands-ts/src/vended-plugins/index.ts @@ -0,0 +1,11 @@ +/** + * Barrel export for all vended plugins. + * + * Provides a single import path for consumers who want all built-in plugins: + * ```typescript + * import { AgentSkills, ContextOffloader, InMemoryStorage } from '@strands-agents/sdk/vended-plugins' + * ``` + */ + +export * from './skills/index.js' +export * from './context-offloader/index.js' diff --git a/strands-ts/src/vended-plugins/skills/__tests__/agent-skills.test.node.ts b/strands-ts/src/vended-plugins/skills/__tests__/agent-skills.test.node.ts new file mode 100644 index 0000000000..e8bd6f3e5d --- /dev/null +++ b/strands-ts/src/vended-plugins/skills/__tests__/agent-skills.test.node.ts @@ -0,0 +1,582 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest' +import { AgentSkills } from '../agent-skills.js' +import { Skill } from '../skill.js' +import { BeforeInvocationEvent } from '../../../hooks/events.js' +import { TextBlock, CachePointBlock } from '../../../types/messages.js' +import { createMockAgent, invokeTrackedHook, type MockAgent } from '../../../__fixtures__/agent-helpers.js' +import { promises as fs } from 'fs' +import * as path from 'path' +import { tmpdir } from 'os' + +describe('AgentSkills', () => { + let testDir: string + + const createSkillDir = async ( + name: string, + content: string, + extraFiles?: Record + ): Promise => { + const dirPath = path.join(testDir, name) + await fs.mkdir(dirPath, { recursive: true }) + await fs.writeFile(path.join(dirPath, 'SKILL.md'), content, 'utf-8') + if (extraFiles) { + for (const [filePath, fileContent] of Object.entries(extraFiles)) { + const fullPath = path.join(dirPath, filePath) + await fs.mkdir(path.dirname(fullPath), { recursive: true }) + await fs.writeFile(fullPath, fileContent, 'utf-8') + } + } + return dirPath + } + + const makeSkill = (name: string, description = `Description of ${name}`, instructions = `Instructions for ${name}`) => + new Skill({ name, description, instructions }) + + beforeEach(async () => { + testDir = path.join(tmpdir(), `agent-skills-test-${Date.now()}-${Math.random().toString(36).slice(2)}`) + await fs.mkdir(testDir, { recursive: true }) + }) + + afterEach(async () => { + await fs.rm(testDir, { recursive: true, force: true }) + }) + + // ── Constructor & skill resolution ────────────────────────────────── + + describe('constructor', () => { + it('resolves Skill instances directly', async () => { + const skill = makeSkill('my-skill') + const plugin = new AgentSkills({ skills: [skill] }) + expect(await plugin.getAvailableSkills()).toHaveLength(1) + expect((await plugin.getAvailableSkills())[0]!.name).toBe('my-skill') + }) + + it('resolves a skill directory path', async () => { + await createSkillDir('my-skill', '---\nname: my-skill\ndescription: A skill\n---\nBody.') + const plugin = new AgentSkills({ skills: [path.join(testDir, 'my-skill')] }) + expect(await plugin.getAvailableSkills()).toHaveLength(1) + }) + + it('resolves a parent directory with multiple skills', async () => { + await createSkillDir('skill-a', '---\nname: skill-a\ndescription: Skill A\n---\nA.') + await createSkillDir('skill-b', '---\nname: skill-b\ndescription: Skill B\n---\nB.') + const plugin = new AgentSkills({ skills: [testDir] }) + expect(await plugin.getAvailableSkills()).toHaveLength(2) + }) + + it('handles mixed sources', async () => { + await createSkillDir('file-skill', '---\nname: file-skill\ndescription: From file\n---\nBody.') + const directSkill = makeSkill('direct-skill') + const plugin = new AgentSkills({ + skills: [directSkill, path.join(testDir, 'file-skill')], + }) + expect(await plugin.getAvailableSkills()).toHaveLength(2) + }) + + it('warns on duplicate names and keeps the last', async () => { + const skill1 = makeSkill('dup', 'First') + const skill2 = makeSkill('dup', 'Second') + const plugin = new AgentSkills({ skills: [skill1, skill2] }) + expect(await plugin.getAvailableSkills()).toHaveLength(1) + expect((await plugin.getAvailableSkills())[0]!.description).toBe('Second') + }) + + it('warns and skips non-existent paths', async () => { + const plugin = new AgentSkills({ skills: ['/does/not/exist'] }) + expect(await plugin.getAvailableSkills()).toHaveLength(0) + }) + + it('gracefully handles a path with malformed SKILL.md', async () => { + const dirPath = path.join(testDir, 'bad-skill') + await fs.mkdir(dirPath, { recursive: true }) + await fs.writeFile(path.join(dirPath, 'SKILL.md'), 'totally broken, no frontmatter at all', 'utf-8') + + const plugin = new AgentSkills({ skills: [dirPath] }) + expect(await plugin.getAvailableSkills()).toHaveLength(0) + }) + + it('loads valid skills from a parent dir containing malformed siblings', async () => { + await fs.mkdir(path.join(testDir, 'good-skill'), { recursive: true }) + await fs.writeFile( + path.join(testDir, 'good-skill', 'SKILL.md'), + '---\nname: good-skill\ndescription: Works\n---\nBody.', + 'utf-8' + ) + await fs.mkdir(path.join(testDir, 'bad-skill'), { recursive: true }) + await fs.writeFile(path.join(testDir, 'bad-skill', 'SKILL.md'), 'no frontmatter', 'utf-8') + + const plugin = new AgentSkills({ skills: [testDir] }) + const skills = await plugin.getAvailableSkills() + expect(skills).toHaveLength(1) + expect(skills[0]!.name).toBe('good-skill') + }) + }) + + // ── Plugin interface ──────────────────────────────────────────────── + + describe('plugin interface', () => { + it('has the correct name', () => { + const plugin = new AgentSkills({ skills: [makeSkill('s')] }) + expect(plugin.name).toBe('strands:agent-skills') + }) + + it('returns one tool named skills from getTools', () => { + const plugin = new AgentSkills({ skills: [makeSkill('s')] }) + const tools = plugin.getTools() + expect(tools).toHaveLength(1) + expect(tools[0]!.name).toBe('skills') + }) + + it('registers a BeforeInvocationEvent hook in initAgent', async () => { + const plugin = new AgentSkills({ skills: [makeSkill('s')] }) + const agent = createMockAgent() + await plugin.initAgent(agent) + expect(agent.trackedHooks).toHaveLength(1) + expect(agent.trackedHooks[0]!.eventType).toBe(BeforeInvocationEvent) + }) + }) + + // ── System prompt injection ───────────────────────────────────────── + + describe('system prompt injection', () => { + let plugin: AgentSkills + let agent: MockAgent + + beforeEach(async () => { + plugin = new AgentSkills({ + skills: [makeSkill('pdf-skill', 'Process PDFs')], + }) + agent = createMockAgent() + await plugin.initAgent(agent) + }) + + const fireBeforeInvocation = async () => { + await invokeTrackedHook(agent, new BeforeInvocationEvent({ agent: agent as any, invocationState: {} })) + } + + it('injects into undefined system prompt', async () => { + delete (agent as any).systemPrompt + await fireBeforeInvocation() + expect(typeof agent.systemPrompt).toBe('string') + expect(agent.systemPrompt as unknown as string).toContain('') + expect(agent.systemPrompt as unknown as string).toContain('pdf-skill') + }) + + it('injects into string system prompt', async () => { + agent.systemPrompt = 'You are a helpful assistant.' + await fireBeforeInvocation() + const prompt = agent.systemPrompt as string + expect(prompt).toContain('You are a helpful assistant.') + expect(prompt).toContain('') + expect(prompt).toContain('pdf-skill') + }) + + it('injects into SystemContentBlock[] prompt', async () => { + agent.systemPrompt = [new TextBlock('You are helpful.'), new CachePointBlock({ cacheType: 'default' })] + await fireBeforeInvocation() + const blocks = agent.systemPrompt as any[] + expect(blocks.length).toBe(3) + // Original blocks preserved + expect(blocks[0]).toBeInstanceOf(TextBlock) + expect((blocks[0] as TextBlock).text).toBe('You are helpful.') + expect(blocks[1]).toBeInstanceOf(CachePointBlock) + // New skills block appended + expect(blocks[2]).toBeInstanceOf(TextBlock) + expect((blocks[2] as TextBlock).text).toContain('') + }) + + it('is idempotent — re-injection replaces previous block', async () => { + agent.systemPrompt = 'Base prompt.' + await fireBeforeInvocation() + const first = agent.systemPrompt as string + const skillsCount = (first.match(//g) ?? []).length + expect(skillsCount).toBe(1) + + // Fire again + await fireBeforeInvocation() + const second = agent.systemPrompt as string + const skillsCount2 = (second.match(//g) ?? []).length + expect(skillsCount2).toBe(1) + expect(second).toContain('Base prompt.') + }) + + it('is idempotent with SystemContentBlock[] prompt', async () => { + agent.systemPrompt = [new TextBlock('Base.')] + await fireBeforeInvocation() + await fireBeforeInvocation() + const blocks = agent.systemPrompt as any[] + // Original block + one skills block (not two) + const skillsBlocks = blocks.filter((b: any) => b instanceof TextBlock && b.text.includes('')) + expect(skillsBlocks).toHaveLength(1) + }) + + it('preserves external modifications to system prompt', async () => { + agent.systemPrompt = 'Original.' + await fireBeforeInvocation() + + // Simulate external modification + agent.systemPrompt = (agent.systemPrompt as string).replace('Original.', 'Modified.') + + await fireBeforeInvocation() + const prompt = agent.systemPrompt as string + expect(prompt).toContain('Modified.') + expect(prompt).toContain('') + }) + + it('XML-escapes special characters in skill metadata', async () => { + const plugin2 = new AgentSkills({ + skills: [makeSkill('test-skill', 'Use when: user says & "goodbye"')], + }) + const agent2 = createMockAgent() + await plugin2.initAgent(agent2) + + const hook = agent2.trackedHooks[0]! + await hook.callback(new BeforeInvocationEvent({ agent: agent2 as any, invocationState: {} })) + + const prompt = agent2.systemPrompt as string + expect(prompt).toContain('<hello>') + expect(prompt).toContain('&') + expect(prompt).toContain('"goodbye"') + }) + + it('includes skill location when path is set', async () => { + const dirPath = await createSkillDir( + 'located-skill', + '---\nname: located-skill\ndescription: Has a path\n---\nBody.' + ) + const filePlugin = new AgentSkills({ skills: [dirPath] }) + const fileAgent = createMockAgent() + await filePlugin.initAgent(fileAgent) + await invokeTrackedHook(fileAgent, new BeforeInvocationEvent({ agent: fileAgent as any, invocationState: {} })) + + const prompt = fileAgent.systemPrompt as string + expect(prompt).toContain('') + expect(prompt).toContain('SKILL.md') + }) + + it('shows "no skills available" when empty', async () => { + const emptyPlugin = new AgentSkills({ skills: [] }) + const emptyAgent = createMockAgent() + await emptyPlugin.initAgent(emptyAgent) + await invokeTrackedHook(emptyAgent, new BeforeInvocationEvent({ agent: emptyAgent as any, invocationState: {} })) + + const prompt = emptyAgent.systemPrompt as string + expect(prompt).toContain('No skills are currently available.') + }) + + it('injects into null system prompt', async () => { + agent.systemPrompt = null as any + await fireBeforeInvocation() + expect(typeof agent.systemPrompt).toBe('string') + expect(agent.systemPrompt as unknown as string).toContain('') + expect(agent.systemPrompt as unknown as string).toContain('pdf-skill') + }) + + it('reflects updated skills after setAvailableSkills', async () => { + agent.systemPrompt = 'Base.' + await fireBeforeInvocation() + expect(agent.systemPrompt as string).toContain('pdf-skill') + + plugin.setAvailableSkills([makeSkill('new-skill', 'A new skill')]) + await fireBeforeInvocation() + const prompt = agent.systemPrompt as string + expect(prompt).toContain('new-skill') + expect(prompt).not.toContain('pdf-skill') + expect(prompt).toContain('Base.') + }) + + it('lists all skills when multiple are available', async () => { + const multiPlugin = new AgentSkills({ + skills: [makeSkill('skill-a', 'First'), makeSkill('skill-b', 'Second'), makeSkill('skill-c', 'Third')], + }) + const multiAgent = createMockAgent() + await multiPlugin.initAgent(multiAgent) + await invokeTrackedHook(multiAgent, new BeforeInvocationEvent({ agent: multiAgent as any, invocationState: {} })) + + const prompt = multiAgent.systemPrompt as string + expect(prompt).toContain('skill-a') + expect(prompt).toContain('skill-b') + expect(prompt).toContain('skill-c') + expect(prompt).toContain('First') + expect(prompt).toContain('Second') + expect(prompt).toContain('Third') + }) + }) + + // ── Tool callback ─────────────────────────────────────────────────── + + describe('tool callback', () => { + let plugin: AgentSkills + let agent: MockAgent + + beforeEach(async () => { + plugin = new AgentSkills({ + skills: [ + new Skill({ + name: 'test-skill', + description: 'A test skill', + instructions: '# Test\nDo the thing.', + allowedTools: ['bash'], + compatibility: 'v1.0+', + }), + ], + }) + agent = createMockAgent() + await plugin.initAgent(agent) + }) + + const invokeTool = async (skillName: string): Promise => { + const tools = plugin.getTools() + const skillsTool = tools[0]! + // Use the stream method to get the result + const gen = skillsTool.stream({ + toolUse: { name: 'skills', toolUseId: 'test-id', input: { skill_name: skillName } }, + agent: agent as any, + invocationState: {}, + interrupt: () => { + throw new Error('interrupt not available in mock context') + }, + }) + let result = await gen.next() + while (!result.done) { + result = await gen.next() + } + // Extract text from the tool result + const content = result.value.content + return content.map((b: any) => b.text ?? '').join('') + } + + it('returns instructions for a valid skill', async () => { + const result = await invokeTool('test-skill') + expect(result).toContain('# Test') + expect(result).toContain('Do the thing.') + }) + + it('includes metadata in the response', async () => { + const result = await invokeTool('test-skill') + expect(result).toContain('Allowed tools: bash') + expect(result).toContain('Compatibility: v1.0+') + }) + + it('returns error for unknown skill', async () => { + const result = await invokeTool('nonexistent') + expect(result).toContain("Skill 'nonexistent' not found") + expect(result).toContain('test-skill') + }) + + it('tracks activated skills in appState', async () => { + await invokeTool('test-skill') + const activated = plugin.getActivatedSkills(agent as any) + expect(activated).toEqual(['test-skill']) + }) + + it('maintains activation order without duplicates', async () => { + // Add a second skill + plugin.setAvailableSkills([makeSkill('skill-a'), makeSkill('skill-b')]) + + await invokeTool('skill-a') + await invokeTool('skill-b') + await invokeTool('skill-a') // re-activate + + const activated = plugin.getActivatedSkills(agent as any) + expect(activated).toEqual(['skill-b', 'skill-a']) + }) + + it('handles skill with no instructions', async () => { + plugin.setAvailableSkills([new Skill({ name: 'empty', description: 'No instructions' })]) + const result = await invokeTool('empty') + expect(result).toContain("Skill 'empty' activated (no instructions available).") + }) + + it('returns validation error for empty skill_name', async () => { + const result = await invokeTool('') + // z.string().min(1) rejects empty strings at the schema level + expect(result.toLowerCase()).toContain('too_small') + }) + }) + + // ── Resource listing ──────────────────────────────────────────────── + + describe('resource listing', () => { + it('lists files from scripts/, references/, assets/', async () => { + const dirPath = await createSkillDir( + 'resource-skill', + '---\nname: resource-skill\ndescription: Has resources\n---\nBody.', + { + 'scripts/setup.sh': '#!/bin/bash', + 'references/api.md': '# API Docs', + 'assets/logo.png': 'binary', + } + ) + const plugin2 = new AgentSkills({ skills: [dirPath] }) + const agent2 = createMockAgent() + await plugin2.initAgent(agent2) + + const tools = plugin2.getTools() + const gen = tools[0]!.stream({ + toolUse: { name: 'skills', toolUseId: 'id', input: { skill_name: 'resource-skill' } }, + agent: agent2 as any, + invocationState: {}, + interrupt: () => { + throw new Error('interrupt not available in mock context') + }, + }) + let result = await gen.next() + while (!result.done) result = await gen.next() + const text = result.value.content.map((b: any) => b.text ?? '').join('') + + expect(text).toContain('scripts/setup.sh') + expect(text).toContain('references/api.md') + expect(text).toContain('assets/logo.png') + }) + + it('handles missing resource directories gracefully', async () => { + const dirPath = await createSkillDir( + 'no-resources', + '---\nname: no-resources\ndescription: No extras\n---\nBody.' + ) + const plugin2 = new AgentSkills({ skills: [dirPath] }) + const agent2 = createMockAgent() + await plugin2.initAgent(agent2) + + const tools = plugin2.getTools() + const gen = tools[0]!.stream({ + toolUse: { name: 'skills', toolUseId: 'id', input: { skill_name: 'no-resources' } }, + agent: agent2 as any, + invocationState: {}, + interrupt: () => { + throw new Error('interrupt not available in mock context') + }, + }) + let result = await gen.next() + while (!result.done) result = await gen.next() + const text = result.value.content.map((b: any) => b.text ?? '').join('') + + expect(text).not.toContain('Available resources') + }) + + it('truncates at maxResourceFiles', async () => { + // Create more files than the limit + const files: Record = {} + for (let i = 0; i < 5; i++) { + files[`scripts/file${i}.sh`] = `script ${i}` + } + const dirPath = await createSkillDir( + 'many-files', + '---\nname: many-files\ndescription: Many resources\n---\nBody.', + files + ) + const plugin2 = new AgentSkills({ skills: [dirPath], maxResourceFiles: 3 }) + const agent2 = createMockAgent() + await plugin2.initAgent(agent2) + + const tools = plugin2.getTools() + const gen = tools[0]!.stream({ + toolUse: { name: 'skills', toolUseId: 'id', input: { skill_name: 'many-files' } }, + agent: agent2 as any, + invocationState: {}, + interrupt: () => { + throw new Error('interrupt not available in mock context') + }, + }) + let result = await gen.next() + while (!result.done) result = await gen.next() + const text = result.value.content.map((b: any) => b.text ?? '').join('') + + expect(text).toContain('truncated at 3 files') + }) + }) + + // ── setAvailableSkills / getAvailableSkills ───────────────────────── + + describe('setAvailableSkills', () => { + it('replaces all skills', async () => { + const plugin2 = new AgentSkills({ skills: [makeSkill('original')] }) + expect(await plugin2.getAvailableSkills()).toHaveLength(1) + + plugin2.setAvailableSkills([makeSkill('new-a'), makeSkill('new-b')]) + expect(await plugin2.getAvailableSkills()).toHaveLength(2) + expect((await plugin2.getAvailableSkills()).map((s) => s.name).sort()).toEqual(['new-a', 'new-b']) + }) + }) + + // ── URL skill resolution ────────────────────────────────────────────── + + describe('URL skill resolution', () => { + const SAMPLE_CONTENT = '---\nname: url-skill\ndescription: A URL skill\n---\n# Instructions\n' + + const mockFetchSuccess = (content: string) => { + vi.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + text: () => Promise.resolve(content), + } as Response) + } + + afterEach(() => { + vi.restoreAllMocks() + }) + + it('resolves a URL string as a skill source', async () => { + mockFetchSuccess(SAMPLE_CONTENT) + + const plugin = new AgentSkills({ skills: ['https://example.com/SKILL.md'] }) + await plugin.initAgent(createMockAgent()) + + expect(await plugin.getAvailableSkills()).toHaveLength(1) + expect((await plugin.getAvailableSkills())[0]!.name).toBe('url-skill') + }) + + it('resolves a mix of URL and local filesystem sources', async () => { + mockFetchSuccess(SAMPLE_CONTENT) + + await createSkillDir('local-skill', '---\nname: local-skill\ndescription: A local skill\n---\nBody.') + + const plugin = new AgentSkills({ + skills: ['https://example.com/SKILL.md', path.join(testDir, 'local-skill')], + }) + await plugin.initAgent(createMockAgent()) + + expect(await plugin.getAvailableSkills()).toHaveLength(2) + const names = new Set((await plugin.getAvailableSkills()).map((s) => s.name)) + expect(names).toEqual(new Set(['url-skill', 'local-skill'])) + }) + + it('skips a failed URL fetch gracefully', async () => { + vi.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: false, + status: 404, + statusText: 'Not Found', + text: () => Promise.resolve(''), + } as Response) + + const plugin = new AgentSkills({ skills: ['https://example.com/broken/SKILL.md'] }) + await plugin.initAgent(createMockAgent()) + + expect(await plugin.getAvailableSkills()).toHaveLength(0) + }) + + it('warns on duplicate skill names from URLs', async () => { + mockFetchSuccess(SAMPLE_CONTENT) + + const plugin = new AgentSkills({ + skills: ['https://example.com/a/SKILL.md', 'https://example.com/b/SKILL.md'], + }) + await plugin.initAgent(createMockAgent()) + + expect(await plugin.getAvailableSkills()).toHaveLength(1) + }) + + it('awaits URL sources in initAgent', async () => { + mockFetchSuccess(SAMPLE_CONTENT) + + const plugin = new AgentSkills({ skills: ['https://example.com/SKILL.md'] }) + const agent = createMockAgent() + await plugin.initAgent(agent) + + expect(await plugin.getAvailableSkills()).toHaveLength(1) + expect(agent.trackedHooks).toHaveLength(1) + }) + }) +}) diff --git a/strands-ts/src/vended-plugins/skills/__tests__/skill.test.node.ts b/strands-ts/src/vended-plugins/skills/__tests__/skill.test.node.ts new file mode 100644 index 0000000000..52e7061ebc --- /dev/null +++ b/strands-ts/src/vended-plugins/skills/__tests__/skill.test.node.ts @@ -0,0 +1,606 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest' +import { Skill } from '../skill.js' +import { promises as fs } from 'fs' +import * as path from 'path' +import { tmpdir } from 'os' + +describe('Skill', () => { + let testDir: string + + const createSkillDir = async (name: string, content: string, filename = 'SKILL.md'): Promise => { + const dirPath = path.join(testDir, name) + await fs.mkdir(dirPath, { recursive: true }) + await fs.writeFile(path.join(dirPath, filename), content, 'utf-8') + return dirPath + } + + beforeEach(async () => { + testDir = path.join(tmpdir(), `skill-test-${Date.now()}-${Math.random().toString(36).slice(2)}`) + await fs.mkdir(testDir, { recursive: true }) + }) + + afterEach(async () => { + await fs.rm(testDir, { recursive: true, force: true }) + }) + + describe('constructor', () => { + it('creates a skill with required fields', () => { + const skill = new Skill({ name: 'test-skill', description: 'A test skill' }) + expect(skill).toEqual( + expect.objectContaining({ + name: 'test-skill', + description: 'A test skill', + instructions: '', + path: undefined, + allowedTools: undefined, + metadata: {}, + license: undefined, + compatibility: undefined, + }) + ) + }) + + it('creates a skill with all fields', () => { + const skill = new Skill({ + name: 'full-skill', + description: 'Full description', + instructions: '# Instructions\nDo things', + path: '/some/path', + allowedTools: ['bash', 'file-editor'], + metadata: { author: 'test' }, + license: 'Apache-2.0', + compatibility: 'v1.0+', + }) + expect(skill).toEqual( + expect.objectContaining({ + name: 'full-skill', + description: 'Full description', + instructions: '# Instructions\nDo things', + path: '/some/path', + allowedTools: ['bash', 'file-editor'], + metadata: { author: 'test' }, + license: 'Apache-2.0', + compatibility: 'v1.0+', + }) + ) + }) + }) + + describe('fromContent', () => { + it('parses valid SKILL.md content', () => { + const content = `--- +name: my-skill +description: Does something useful +--- +# Instructions +Follow these steps.` + + const skill = Skill.fromContent(content) + expect(skill.name).toBe('my-skill') + expect(skill.description).toBe('Does something useful') + expect(skill.instructions).toBe('# Instructions\nFollow these steps.') + }) + + it('parses content with allowed-tools as space-delimited string', () => { + const content = `--- +name: my-skill +description: A skill +allowed-tools: bash file-editor +--- +Instructions here.` + + const skill = Skill.fromContent(content) + expect(skill.allowedTools).toEqual(['bash', 'file-editor']) + }) + + it('parses content with allowed-tools as YAML list', () => { + const content = `--- +name: my-skill +description: A skill +allowed-tools: + - bash + - file-editor +--- +Instructions here.` + + const skill = Skill.fromContent(content) + expect(skill.allowedTools).toEqual(['bash', 'file-editor']) + }) + + it('parses content with allowed_tools underscore variant', () => { + const content = `--- +name: my-skill +description: A skill +allowed_tools: bash notebook +--- +Instructions here.` + + const skill = Skill.fromContent(content) + expect(skill.allowedTools).toEqual(['bash', 'notebook']) + }) + + it('parses content with metadata', () => { + const content = `--- +name: my-skill +description: A skill +metadata: + author: test-user + version: 1 +--- +Body.` + + const skill = Skill.fromContent(content) + expect(skill.metadata).toEqual({ author: 'test-user', version: 1 }) + }) + + it('parses content with license and compatibility', () => { + const content = `--- +name: my-skill +description: A skill +license: MIT +compatibility: strands-agents >= 1.0 +--- +Body.` + + const skill = Skill.fromContent(content) + expect(skill.license).toBe('MIT') + expect(skill.compatibility).toBe('strands-agents >= 1.0') + }) + + it('throws if content does not start with ---', () => { + expect(() => Skill.fromContent('no frontmatter')).toThrow('SKILL.md must start with --- frontmatter delimiter') + }) + + it('throws if closing --- is missing', () => { + expect(() => Skill.fromContent('---\nname: test\n')).toThrow('SKILL.md frontmatter missing closing --- delimiter') + }) + + it('throws if name is missing', () => { + const content = `--- +description: no name +--- +Body.` + expect(() => Skill.fromContent(content)).toThrow("must have a 'name' field") + }) + + it('throws if description is missing', () => { + const content = `--- +name: my-skill +--- +Body.` + expect(() => Skill.fromContent(content)).toThrow("must have a 'description' field") + }) + + it('handles empty body', () => { + const content = `--- +name: my-skill +description: A skill +---` + + const skill = Skill.fromContent(content) + expect(skill.instructions).toBe('') + }) + + it('warns but does not throw for invalid name in lenient mode', () => { + const content = `--- +name: INVALID_NAME +description: A skill +--- +Body.` + + // Should not throw in lenient mode (default) + const skill = Skill.fromContent(content) + expect(skill.name).toBe('INVALID_NAME') + }) + + it('throws for invalid name in strict mode', () => { + const content = `--- +name: INVALID_NAME +description: A skill +--- +Body.` + + expect(() => Skill.fromContent(content, { strict: true })).toThrow('skill name should be') + }) + + it('throws for empty name in strict mode', () => { + const content = `--- +name: "" +description: A skill +--- +Body.` + + expect(() => Skill.fromContent(content)).toThrow("must have a 'name' field") + }) + + it('throws for name exceeding length limit in strict mode', () => { + const longName = 'a'.repeat(65) + const content = `--- +name: ${longName} +description: A skill +--- +Body.` + + expect(() => Skill.fromContent(content, { strict: true })).toThrow('exceeds 64 character limit') + }) + + it('throws for consecutive hyphens in strict mode', () => { + const content = `--- +name: my--skill +description: A skill +--- +Body.` + + expect(() => Skill.fromContent(content, { strict: true })).toThrow('consecutive hyphens') + }) + + it('handles body containing --- horizontal rules', () => { + const content = `--- +name: my-skill +description: A skill +--- +# Instructions + +First section. + +--- + +Second section after horizontal rule. + +--- + +Third section.` + + const skill = Skill.fromContent(content) + expect(skill.name).toBe('my-skill') + expect(skill.instructions).toContain('First section.') + expect(skill.instructions).toContain('---') + expect(skill.instructions).toContain('Third section.') + }) + + it('handles body with only whitespace after frontmatter', () => { + const content = `--- +name: my-skill +description: A skill +--- + + ` + + const skill = Skill.fromContent(content) + expect(skill.name).toBe('my-skill') + expect(skill.instructions).toBe('') + }) + + it('handles frontmatter value containing --- inline', () => { + const content = `--- +name: my-skill +description: Use this --- for special cases +--- +Body.` + + const skill = Skill.fromContent(content) + expect(skill.name).toBe('my-skill') + expect(skill.description).toBe('Use this --- for special cases') + }) + + it('ignores non-object metadata', () => { + const content = `--- +name: my-skill +description: A skill +metadata: just-a-string +--- +Body.` + + const skill = Skill.fromContent(content) + expect(skill.metadata).toEqual({}) + }) + + it('ignores array metadata', () => { + const content = `--- +name: my-skill +description: A skill +metadata: + - item1 + - item2 +--- +Body.` + + const skill = Skill.fromContent(content) + expect(skill.metadata).toEqual({}) + }) + + it('handles allowed-tools as empty string', () => { + const content = `--- +name: my-skill +description: A skill +allowed-tools: "" +--- +Body.` + + const skill = Skill.fromContent(content) + expect(skill.allowedTools).toBeUndefined() + }) + + it('filters null entries from allowed-tools array', () => { + const content = `--- +name: my-skill +description: A skill +allowed-tools: + - bash + - null + - file-editor +--- +Body.` + + const skill = Skill.fromContent(content) + expect(skill.allowedTools).toEqual(['bash', 'file-editor']) + }) + }) + + describe('fromFile', () => { + it('loads a skill from a directory', async () => { + const dirPath = await createSkillDir( + 'my-skill', + `--- +name: my-skill +description: A test skill +--- +# Instructions +Do the thing.` + ) + + const skill = Skill.fromFile(dirPath) + expect(skill.name).toBe('my-skill') + expect(skill.description).toBe('A test skill') + expect(skill.instructions).toBe('# Instructions\nDo the thing.') + expect(skill.path).toBe(dirPath) + }) + + it('loads a skill from a SKILL.md file path', async () => { + const dirPath = await createSkillDir( + 'my-skill', + `--- +name: my-skill +description: A test skill +--- +Body.` + ) + const filePath = path.join(dirPath, 'SKILL.md') + + const skill = Skill.fromFile(filePath) + expect(skill.name).toBe('my-skill') + expect(skill.path).toBe(dirPath) + }) + + it('finds lowercase skill.md as fallback', async () => { + const dirPath = await createSkillDir( + 'my-skill', + `--- +name: my-skill +description: A test skill +--- +Body.`, + 'skill.md' + ) + + const skill = Skill.fromFile(dirPath) + expect(skill.name).toBe('my-skill') + }) + + it('throws for non-existent path', () => { + expect(() => Skill.fromFile('/does/not/exist')).toThrow('does not exist') + }) + + it('warns when skill name does not match directory name', async () => { + const dirPath = await createSkillDir( + 'wrong-dir-name', + `--- +name: actual-skill-name +description: Mismatched name +--- +Body.` + ) + + // Should not throw in lenient mode + const skill = Skill.fromFile(dirPath) + expect(skill.name).toBe('actual-skill-name') + }) + + it('throws when skill name does not match directory in strict mode', async () => { + const dirPath = await createSkillDir( + 'wrong-dir-name', + `--- +name: actual-skill-name +description: Mismatched name +--- +Body.` + ) + + expect(() => Skill.fromFile(dirPath, { strict: true })).toThrow('does not match parent directory name') + }) + }) + + describe('fromDirectory', () => { + it('loads all skills from a directory', async () => { + await createSkillDir( + 'skill-a', + `--- +name: skill-a +description: First skill +--- +Instructions A.` + ) + await createSkillDir( + 'skill-b', + `--- +name: skill-b +description: Second skill +--- +Instructions B.` + ) + + const skills = Skill.fromDirectory(testDir) + expect(skills).toHaveLength(2) + expect(skills.map((s) => s.name).sort()).toEqual(['skill-a', 'skill-b']) + }) + + it('skips directories without SKILL.md', async () => { + await createSkillDir( + 'valid-skill', + `--- +name: valid-skill +description: Has SKILL.md +--- +Body.` + ) + // Create a directory without SKILL.md + await fs.mkdir(path.join(testDir, 'no-skill-md'), { recursive: true }) + await fs.writeFile(path.join(testDir, 'no-skill-md', 'README.md'), 'not a skill', 'utf-8') + + const skills = Skill.fromDirectory(testDir) + expect(skills).toHaveLength(1) + expect(skills[0]!.name).toBe('valid-skill') + }) + + it('skips non-directory children', async () => { + await createSkillDir( + 'valid-skill', + `--- +name: valid-skill +description: Has SKILL.md +--- +Body.` + ) + // Create a plain file in the parent directory + await fs.writeFile(path.join(testDir, 'some-file.txt'), 'not a directory', 'utf-8') + + const skills = Skill.fromDirectory(testDir) + expect(skills).toHaveLength(1) + }) + + it('skips skills with invalid content', async () => { + await createSkillDir( + 'valid-skill', + `--- +name: valid-skill +description: Good skill +--- +Body.` + ) + await createSkillDir( + 'bad-skill', + `--- +description: Missing name +--- +Body.` + ) + + const skills = Skill.fromDirectory(testDir) + expect(skills).toHaveLength(1) + expect(skills[0]!.name).toBe('valid-skill') + }) + + it('throws for non-existent directory', () => { + expect(() => Skill.fromDirectory('/does/not/exist')).toThrow('skills directory does not exist') + }) + + it('returns empty array for directory with no skills', async () => { + const skills = Skill.fromDirectory(testDir) + expect(skills).toEqual([]) + }) + + it('skips skills with completely broken SKILL.md (no frontmatter)', async () => { + await createSkillDir( + 'valid-skill', + `--- +name: valid-skill +description: Good skill +--- +Body.` + ) + await createSkillDir('broken-skill', 'totally broken, no frontmatter at all') + + const skills = Skill.fromDirectory(testDir) + expect(skills).toHaveLength(1) + expect(skills[0]!.name).toBe('valid-skill') + }) + }) + + describe('fromUrl', () => { + const SAMPLE_CONTENT = '---\nname: my-skill\ndescription: A remote skill\n---\nRemote instructions.\n' + + const mockFetchSuccess = (content: string) => { + vi.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + text: () => Promise.resolve(content), + } as Response) + } + + afterEach(() => { + vi.restoreAllMocks() + }) + + it('returns a Skill from a valid URL', async () => { + mockFetchSuccess(SAMPLE_CONTENT) + + const skill = await Skill.fromUrl('https://raw.githubusercontent.com/org/repo/main/SKILL.md') + + expect(skill).toBeInstanceOf(Skill) + expect(skill.name).toBe('my-skill') + expect(skill.description).toBe('A remote skill') + expect(skill.instructions).toContain('Remote instructions.') + expect(skill.path).toBeUndefined() + }) + + it('rejects non-HTTPS URLs', async () => { + await expect(Skill.fromUrl('./local-path')).rejects.toThrow('not a valid HTTPS URL') + }) + + it('rejects http:// URLs', async () => { + await expect(Skill.fromUrl('http://example.com/SKILL.md')).rejects.toThrow('not a valid HTTPS URL') + }) + + it('throws on HTTP error responses', async () => { + vi.spyOn(globalThis, 'fetch').mockResolvedValue({ + ok: false, + status: 404, + statusText: 'Not Found', + text: () => Promise.resolve(''), + } as Response) + + await expect(Skill.fromUrl('https://example.com/SKILL.md')).rejects.toThrow('HTTP 404') + }) + + it('throws on network errors', async () => { + vi.spyOn(globalThis, 'fetch').mockRejectedValue(new Error('Connection refused')) + + await expect(Skill.fromUrl('https://example.com/SKILL.md')).rejects.toThrow('failed to fetch') + }) + + it('forwards strict mode to fromContent', async () => { + const badContent = '---\nname: BAD_NAME\ndescription: Bad\n---\nBody.' + mockFetchSuccess(badContent) + + await expect(Skill.fromUrl('https://example.com/SKILL.md', { strict: true })).rejects.toThrow( + 'skill name should be' + ) + }) + + it('throws on invalid content (e.g. HTML page)', async () => { + mockFetchSuccess('Not a SKILL.md') + + await expect(Skill.fromUrl('https://example.com/SKILL.md')).rejects.toThrow('frontmatter') + }) + }) + + describe('classmethods', () => { + it('has fromFile, fromContent, fromDirectory, and fromUrl', () => { + expect(typeof Skill.fromFile).toBe('function') + expect(typeof Skill.fromContent).toBe('function') + expect(typeof Skill.fromDirectory).toBe('function') + expect(typeof Skill.fromUrl).toBe('function') + }) + }) +}) diff --git a/strands-ts/src/vended-plugins/skills/agent-skills.ts b/strands-ts/src/vended-plugins/skills/agent-skills.ts new file mode 100644 index 0000000000..5154c61d96 --- /dev/null +++ b/strands-ts/src/vended-plugins/skills/agent-skills.ts @@ -0,0 +1,493 @@ +/** + * AgentSkills plugin for integrating Agent Skills into Strands agents. + * + * This module provides the AgentSkills class that implements the Plugin + * interface to add Agent Skills support. The plugin registers a tool for + * activating skills and injects skill metadata into the system prompt. + */ + +import { readdirSync, statSync, existsSync } from 'fs' +import { join, resolve, relative, sep } from 'path' +import { z } from 'zod' +import { tool } from '../../tools/tool-factory.js' +import { BeforeInvocationEvent } from '../../hooks/events.js' +import { TextBlock, type SystemContentBlock } from '../../types/messages.js' +import { logger } from '../../logging/logger.js' +import { Skill } from './skill.js' +import type { Plugin } from '../../plugins/plugin.js' +import type { LocalAgent } from '../../types/agent.js' +import type { Tool } from '../../tools/tool.js' +import type { ToolContext } from '../../tools/tool.js' + +/** A single skill source: filesystem path string, HTTPS URL string, or Skill instance. */ +export type SkillSource = string | Skill + +/** Configuration for the AgentSkills plugin. */ +export interface AgentSkillsConfig { + /** + * One or more skill sources. Each element can be: + * - A `Skill` instance + * - A path to a skill directory (containing SKILL.md) + * - A path to a parent directory (containing skill subdirectories) + * - An `https://` URL pointing directly to raw SKILL.md content + */ + skills: SkillSource[] + + /** Maximum number of resource files to list in skill responses. Defaults to 20. */ + maxResourceFiles?: number | undefined + + /** If true, throw on skill validation issues. If false (default), warn and load anyway. */ + strict?: boolean | undefined + + /** Custom key for storing plugin state in `agent.appState`. Defaults to `'agent_skills'`. */ + stateKey?: string | undefined +} + +const DEFAULT_STATE_KEY = 'agent_skills' +const RESOURCE_DIRS = ['scripts', 'references', 'assets'] +const DEFAULT_MAX_RESOURCE_FILES = 20 + +/** + * Escape XML special characters in text content. + */ +function escapeXml(text: string): string { + return text + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, ''') +} + +/** + * Plugin that integrates Agent Skills into a Strands agent. + * + * Provides: + * 1. A `skills` tool that allows the agent to activate skills on demand + * 2. System prompt injection of available skill metadata before each invocation + * 3. Session persistence of activated skill state via `agent.appState` + * + * Skills can be provided as filesystem paths (to individual skill directories or + * parent directories containing multiple skills), HTTPS URLs pointing to raw + * SKILL.md content, or as pre-built `Skill` instances. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { Skill, AgentSkills } from '@strands-agents/sdk/vended-plugins/skills' + * + * // Load from filesystem + * const plugin = new AgentSkills({ + * skills: ['./skills/pdf-processing', './skills/'], + * }) + * + * // Or provide Skill instances directly + * const skill = new Skill({ name: 'my-skill', description: 'A custom skill', instructions: 'Do the thing' }) + * const plugin = new AgentSkills({ skills: [skill] }) + * + * const agent = new Agent({ model, plugins: [plugin] }) + * ``` + */ +export class AgentSkills implements Plugin { + readonly name = 'strands:agent-skills' + + private _skills: Map + private readonly _maxResourceFiles: number + /** When true, skill validation errors throw instead of logging warnings. */ + private readonly _strict: boolean + private readonly _stateKey: string + /** Resolves when all async skill sources (e.g. URLs) have been loaded. */ + private _ready: Promise + + constructor(config: AgentSkillsConfig) { + this._strict = config.strict ?? false + this._maxResourceFiles = config.maxResourceFiles ?? DEFAULT_MAX_RESOURCE_FILES + this._stateKey = config.stateKey ?? DEFAULT_STATE_KEY + const { skills, ready } = this._resolveSkills(config.skills) + this._skills = skills + this._ready = ready + } + + /** + * Initialize the plugin with the agent instance. + * + * Waits for any async skill sources (e.g. URLs) to finish loading, then + * registers a BeforeInvocationEvent hook that injects skill metadata + * into the system prompt before each invocation. + */ + async initAgent(agent: LocalAgent): Promise { + await this._ready + + if (this._skills.size === 0) { + logger.warn('no skills were loaded, the agent will have no skills available') + } + logger.debug(`skill_count=<${this._skills.size}> | skills plugin initialized`) + + agent.addHook(BeforeInvocationEvent, async (event) => { + await this._ready + this._injectSkillsXml(event.agent) + }) + } + + /** + * Returns the skills activation tool for auto-registration with the agent. + */ + getTools(): Tool[] { + return [this._createSkillsTool()] + } + + /** + * Get the list of available skills. + */ + async getAvailableSkills(): Promise { + await this._ready + return [...this._skills.values()] + } + + /** + * Replace all available skills. + * + * Each element can be a `Skill` instance, a path to a skill directory + * (containing SKILL.md), a path to a parent directory containing skill + * subdirectories, or an `https://` URL pointing directly to raw SKILL.md + * content. + * + * Note: this does not persist state or deactivate skills on any agent. + * Active skill state is managed per-agent and will be reconciled on the + * next tool call or invocation. + */ + setAvailableSkills(skills: SkillSource[]): void { + const { skills: resolved, ready } = this._resolveSkills(skills) + this._skills = resolved + this._ready = ready + } + + /** + * Get the list of skills activated by the given agent. + * Returns skill names in activation order (most recent last). + */ + getActivatedSkills(agent: LocalAgent): readonly string[] { + return (this._getStateField(agent, 'activatedSkills') as string[] | undefined) ?? [] + } + + /** + * Resolve a list of skill sources into Skill instances. + * + * Each source can be a Skill instance, a path to a skill directory, + * a path to a parent directory containing multiple skills, or an + * HTTPS URL pointing to a SKILL.md file. + * + * Synchronous sources (Skill instances and filesystem paths) are resolved + * immediately into the returned map. Async sources (URLs) are resolved in + * the background; the returned `ready` promise resolves when all URL + * fetches have completed and the map has been updated. + */ + private _resolveSkills(sources: SkillSource[]): { skills: Map; ready: Promise } { + const resolved = new Map() + const asyncTasks: Promise[] = [] + + for (const source of sources) { + if (source instanceof Skill) { + if (resolved.has(source.name)) { + logger.warn(`name=<${source.name}> | duplicate skill name, overwriting previous skill`) + } + resolved.set(source.name, source) + } else if (typeof source === 'string' && source.startsWith('https://')) { + asyncTasks.push( + Skill.fromUrl(source, { strict: this._strict }).then( + (skill) => { + if (resolved.has(skill.name)) { + logger.warn(`name=<${skill.name}> | duplicate skill name, overwriting previous skill`) + } + resolved.set(skill.name, skill) + }, + (error) => { + logger.warn(`url=<${source}> | failed to load skill from URL: ${error}`) + } + ) + ) + } else { + const p = source as string + const resolvedPath = resolve(p) + + // Probe the filesystem to decide which loader to use instead of + // relying on exceptions for control flow. + const isDir = existsSync(resolvedPath) && statSync(resolvedPath).isDirectory() + const isSkillFile = + existsSync(resolvedPath) && statSync(resolvedPath).isFile() && resolvedPath.toLowerCase().endsWith('skill.md') + const hasSkillMd = + isDir && + ['SKILL.md', 'skill.md'].some((name) => { + const candidate = join(resolvedPath, name) + return existsSync(candidate) && statSync(candidate).isFile() + }) + + if (isSkillFile || hasSkillMd) { + // Single skill directory (or direct SKILL.md path) + try { + const skill = Skill.fromFile(p, { strict: this._strict }) + if (resolved.has(skill.name)) { + logger.warn(`name=<${skill.name}> | duplicate skill name, overwriting previous skill`) + } + resolved.set(skill.name, skill) + } catch (error) { + logger.warn(`path=<${p}> | failed to load skill: ${error}`) + } + } else if (isDir) { + // Parent directory containing skill subdirectories + try { + for (const skill of Skill.fromDirectory(p, { strict: this._strict })) { + if (resolved.has(skill.name)) { + logger.warn(`name=<${skill.name}> | duplicate skill name, overwriting previous skill`) + } + resolved.set(skill.name, skill) + } + } catch (error) { + logger.warn(`path=<${p}> | failed to load skills from directory: ${error}`) + } + } else { + logger.warn(`path=<${p}> | skill source does not exist or is not a valid path`) + } + } + } + + let ready: Promise + if (asyncTasks.length > 0) { + ready = Promise.all(asyncTasks).then(() => { + logger.debug( + `source_count=<${sources.length}>, resolved_count=<${resolved.size}> | skills resolved (including async)` + ) + }) + } else { + logger.debug(`source_count=<${sources.length}>, resolved_count=<${resolved.size}> | skills resolved`) + ready = Promise.resolve() + } + + return { skills: resolved, ready } + } + + /** + * Create the skills activation tool using the tool() factory with Zod schema. + */ + private _createSkillsTool(): Tool { + return tool({ + name: 'skills', + description: + 'Activate a skill to load its full instructions. ' + + 'Use this tool to load the complete instructions for a skill listed in ' + + 'the available_skills section of your system prompt.', + inputSchema: z.object({ + skill_name: z.string().min(1).describe('Name of the skill to activate'), + }), + callback: async (input: { skill_name: string }, context?: ToolContext): Promise => { + if (context == null) { + throw new Error('skills tool requires a ToolContext with an agent reference') + } + await this._ready + return this._activateSkill(input.skill_name, context) + }, + }) + } + + /** + * Handle skill activation from the tool callback. + */ + private _activateSkill(skillName: string, context: ToolContext): string { + const found = this._skills.get(skillName) + if (found == null) { + const available = [...this._skills.keys()].join(', ') + return `Skill '${skillName}' not found. Available skills: ${available}` + } + + logger.debug(`skill_name=<${skillName}> | skill activated`) + this._trackActivatedSkill(context.agent, skillName) + return this._formatSkillResponse(found) + } + + /** + * Record a skill activation in agent state. + * Maintains an ordered list of activated skill names (most recent last), without duplicates. + */ + private _trackActivatedSkill(agent: LocalAgent, skillName: string): void { + const activated = (this._getStateField(agent, 'activatedSkills') as string[] | undefined) ?? [] + this._setStateField(agent, 'activatedSkills', [...activated.filter((n) => n !== skillName), skillName]) + } + + /** + * Get a field from the plugin's per-agent state dict. + */ + private _getStateField(agent: LocalAgent, key: string): unknown { + const data = agent.appState.get(this._stateKey) + if (data != null && typeof data === 'object' && !Array.isArray(data)) { + return (data as Record)[key] + } + return undefined + } + + /** + * Set a single field in the plugin's per-agent state dict. + */ + private _setStateField(agent: LocalAgent, key: string, value: unknown): void { + const data = agent.appState.get(this._stateKey) + if (data != null && (typeof data !== 'object' || Array.isArray(data))) { + throw new TypeError(`expected object for state key '${this._stateKey}', got ${typeof data}`) + } + const record = (data ?? {}) as Record + record[key] = value + agent.appState.set(this._stateKey, record) + } + + /** + * Inject skill metadata into the agent's system prompt. + * + * Removes the previously injected XML block (if any) via exact string + * replacement, then appends a fresh one. Uses agent state to track the + * injected XML per-agent, so a single plugin instance can be shared + * across multiple agents safely. + */ + private _injectSkillsXml(agent: LocalAgent): void { + const skillsXml = this._generateSkillsXml() + const systemPrompt = agent.systemPrompt + + if (systemPrompt == null || typeof systemPrompt === 'string') { + let currentPrompt = systemPrompt ?? '' + + // Remove previously injected XML by exact match + const lastInjectedXml = this._getStateField(agent, 'lastInjectedXml') as string | undefined + if (lastInjectedXml != null) { + if (currentPrompt.includes(lastInjectedXml)) { + currentPrompt = currentPrompt.replace(lastInjectedXml, '') + } else { + logger.warn('unable to find previously injected skills XML in system prompt, re-appending') + } + } + + const injection = `\n\n${skillsXml}` + const newPrompt = currentPrompt ? `${currentPrompt}${injection}` : skillsXml + const newInjectedXml = currentPrompt ? injection : skillsXml + + this._setStateField(agent, 'lastInjectedXml', newInjectedXml) + agent.systemPrompt = newPrompt + } else { + // SystemContentBlock[] — remove previous block by exact text match, append new one + const lastInjectedXml = this._getStateField(agent, 'lastInjectedXml') as string | undefined + let filtered: SystemContentBlock[] + if (lastInjectedXml != null) { + filtered = systemPrompt.filter((block) => !(block.type === 'textBlock' && block.text === lastInjectedXml)) + if (filtered.length === systemPrompt.length) { + logger.warn('unable to find previously injected skills XML in system prompt, re-appending') + } + } else { + filtered = [...systemPrompt] + } + + this._setStateField(agent, 'lastInjectedXml', skillsXml) + filtered.push(new TextBlock(skillsXml)) + agent.systemPrompt = filtered + } + } + + /** + * Generate the XML block listing available skills for the system prompt. + * + * @example Output with skills: + * ```xml + * + * + * pdf-processing + * Extract text and tables from PDF files + * /path/to/pdf-processing/SKILL.md + * + * + * ``` + */ + private _generateSkillsXml(): string { + if (this._skills.size === 0) { + return '\nNo skills are currently available.\n' + } + + const lines: string[] = [''] + + for (const skill of this._skills.values()) { + lines.push('') + lines.push(`${escapeXml(skill.name)}`) + lines.push(`${escapeXml(skill.description)}`) + if (skill.path != null) { + lines.push(`${escapeXml(join(skill.path, 'SKILL.md'))}`) + } + lines.push('') + } + + lines.push('') + return lines.join('\n') + } + + /** + * Format the tool response when a skill is activated. + * + * Includes the full instructions along with relevant metadata fields + * and a listing of available resource files. + */ + private _formatSkillResponse(skill: Skill): string { + if (!skill.instructions) { + return `Skill '${skill.name}' activated (no instructions available).` + } + + const parts: string[] = [skill.instructions] + + const metadataLines: string[] = [] + if (skill.allowedTools != null && skill.allowedTools.length > 0) { + metadataLines.push(`Allowed tools: ${skill.allowedTools.join(', ')}`) + } + if (skill.compatibility != null) { + metadataLines.push(`Compatibility: ${skill.compatibility}`) + } + if (skill.path != null) { + metadataLines.push(`Location: ${join(skill.path, 'SKILL.md')}`) + } + + if (metadataLines.length > 0) { + parts.push('\n---\n' + metadataLines.join('\n')) + } + + if (skill.path != null) { + const resources = this._listSkillResources(skill.path) + if (resources.length > 0) { + parts.push('\nAvailable resources:\n' + resources.map((r) => ` ${r}`).join('\n')) + } + } + + return parts.join('\n') + } + + /** + * List resource files in a skill's optional directories. + * + * Scans `scripts/`, `references/`, and `assets/` subdirectories for files, + * returning relative paths. Results are capped at maxResourceFiles. + */ + private _listSkillResources(skillPath: string): string[] { + const files: string[] = [] + + for (const dirName of RESOURCE_DIRS) { + const resourceDir = join(skillPath, dirName) + if (!existsSync(resourceDir) || !statSync(resourceDir).isDirectory()) { + continue + } + + const entries = readdirSync(resourceDir, { recursive: true, encoding: 'utf-8' }) + for (const entry of entries.sort()) { + const fullPath = join(resourceDir, entry) + if (!existsSync(fullPath) || !statSync(fullPath).isFile()) continue + + files.push(relative(skillPath, fullPath).split(sep).join('/')) + if (files.length >= this._maxResourceFiles) { + files.push(`... (truncated at ${this._maxResourceFiles} files)`) + return files + } + } + } + + return files + } +} diff --git a/strands-ts/src/vended-plugins/skills/index.ts b/strands-ts/src/vended-plugins/skills/index.ts new file mode 100644 index 0000000000..ed1fe957db --- /dev/null +++ b/strands-ts/src/vended-plugins/skills/index.ts @@ -0,0 +1,31 @@ +/** + * AgentSkills.io integration for Strands Agents. + * + * This module provides the AgentSkills plugin and Skill data model for + * loading and managing AgentSkills.io skills. Skills enable progressive + * disclosure of instructions: metadata is injected into the system prompt + * upfront, and full instructions are loaded on demand via a tool. + * + * @example + * ```typescript + * import { Agent } from '@strands-agents/sdk' + * import { Skill, AgentSkills } from '@strands-agents/sdk/vended-plugins/skills' + * + * // Load from filesystem + * const plugin = new AgentSkills({ + * skills: ['./skills/pdf-processing', './skills/'], + * }) + * + * // Or provide Skill instances directly + * const skill = new Skill({ name: 'my-skill', description: 'A custom skill', instructions: 'Do the thing' }) + * const plugin = new AgentSkills({ skills: [skill] }) + * + * const agent = new Agent({ model, plugins: [plugin] }) + * ``` + */ + +export { Skill } from './skill.js' +export type { SkillConfig } from './skill.js' + +export { AgentSkills } from './agent-skills.js' +export type { AgentSkillsConfig, SkillSource } from './agent-skills.js' diff --git a/strands-ts/src/vended-plugins/skills/skill.ts b/strands-ts/src/vended-plugins/skills/skill.ts new file mode 100644 index 0000000000..51e5c43e4a --- /dev/null +++ b/strands-ts/src/vended-plugins/skills/skill.ts @@ -0,0 +1,438 @@ +/** + * Skill data model and loading utilities for AgentSkills.io skills. + * + * This module defines the Skill class and provides static methods for + * discovering, parsing, and loading skills from the filesystem, raw content, + * or HTTPS URLs. Skills are directories containing a SKILL.md file with YAML + * frontmatter metadata and markdown instructions. + */ + +import { readFileSync, readdirSync, statSync, existsSync } from 'fs' +import { resolve, join, basename } from 'path' +import { parse as parseYaml } from 'yaml' +import { logger } from '../../logging/logger.js' + +const SKILL_NAME_PATTERN = /^[a-z0-9]([a-z0-9-]*[a-z0-9])?$/ +const MAX_SKILL_NAME_LENGTH = 64 +const SKILL_HTTPS_FETCH_TIMEOUT_MS = 30_000 + +/** + * Configuration for creating a Skill instance. + */ +export interface SkillConfig { + /** Unique identifier for the skill (1-64 chars, lowercase alphanumeric + hyphens). */ + name: string + /** Human-readable description of what the skill does. */ + description: string + /** Full markdown instructions from the SKILL.md body. */ + instructions?: string | undefined + /** Filesystem path to the skill directory, if loaded from disk. */ + path?: string | undefined + /** List of tool names the skill is allowed to use. (Experimental: not yet enforced) */ + allowedTools?: string[] | undefined + /** Additional key-value metadata from the SKILL.md frontmatter. */ + metadata?: Record | undefined + /** License identifier (e.g., "Apache-2.0"). */ + license?: string | undefined + /** Compatibility information string. */ + compatibility?: string | undefined +} + +/** + * Find the SKILL.md file in a skill directory. + * + * Searches for SKILL.md (case-sensitive preferred) or skill.md as a fallback. + */ +function findSkillMd(skillDir: string): string { + for (const name of ['SKILL.md', 'skill.md']) { + const candidate = join(skillDir, name) + if (existsSync(candidate) && statSync(candidate).isFile()) { + return candidate + } + } + throw new Error(`path=<${skillDir}> | no SKILL.md found in skill directory`) +} + +/** + * Parse YAML frontmatter and body from SKILL.md content. + * + * Extracts the YAML frontmatter between `---` delimiters and returns + * parsed key-value pairs along with the remaining markdown body. + */ +function parseFrontmatter(content: string): { frontmatter: Record; body: string } { + const stripped = content.trim() + if (!stripped.startsWith('---')) { + throw new Error('SKILL.md must start with --- frontmatter delimiter') + } + + // Find the closing --- delimiter (first line after the opener that is only dashes) + const match = stripped.substring(3).match(/\n^---\s*$/m) + if (match == null || match.index == null) { + throw new Error('SKILL.md frontmatter missing closing --- delimiter') + } + + const frontmatterStr = stripped.substring(3, match.index + 3).trim() + const body = stripped.substring(match.index + 3 + match[0].length).trim() + + let result: unknown + try { + result = parseYaml(frontmatterStr) + } catch { + // AgentSkills spec recommends handling malformed YAML (e.g. unquoted colons in values) + // to improve cross-client compatibility. + logger.warn('YAML parse failed, retrying with colon-quoting fallback') + const fixed = fixYamlColons(frontmatterStr) + result = parseYaml(fixed) + } + + const frontmatter: Record = + typeof result === 'object' && result !== null ? (result as Record) : {} + return { frontmatter, body } +} + +/** + * Attempt to fix common YAML issues like unquoted colons in values. + * + * Wraps values containing colons in double quotes to handle cases like: + * `description: Use this skill when: the user asks about PDFs` + */ +function fixYamlColons(yamlStr: string): string { + return yamlStr + .split('\n') + .map((line) => { + const match = line.match(/^(\s*\w[\w-]*):\s+(.+)$/) + if (match) { + const [, key, value] = match + if (value && value.includes(':') && !value.startsWith('"') && !value.startsWith("'")) { + // Escape backslashes and double-quotes inside the value before wrapping, + // otherwise values like `Use when: user says "hello"` produce broken YAML. + const escaped = value.replace(/\\/g, '\\\\').replace(/"/g, '\\"') + return `${key}: "${escaped}"` + } + } + return line + }) + .join('\n') +} + +/** + * Validate a skill name per the AgentSkills.io specification. + * + * In lenient mode (default), logs warnings for cosmetic issues but does not throw. + * In strict mode, throws Error for any validation failure. + * + * Rules checked: + * - 1-64 characters long + * - Lowercase alphanumeric characters and hyphens only + * - Cannot start or end with a hyphen + * - No consecutive hyphens + * - Must match parent directory name (if loaded from disk) + */ +function validateSkillName(name: string, dirPath?: string, options?: { strict?: boolean }): void { + const strict = options?.strict ?? false + + if (!name) { + throw new Error('Skill name cannot be empty') + } + + if (name.length > MAX_SKILL_NAME_LENGTH) { + const msg = `name=<${name}> | skill name exceeds ${MAX_SKILL_NAME_LENGTH} character limit` + if (strict) throw new Error(msg) + logger.warn(msg) + } + + if (!SKILL_NAME_PATTERN.test(name)) { + const msg = `name=<${name}> | skill name should be 1-64 lowercase alphanumeric characters or hyphens, should not start/end with hyphen` + if (strict) throw new Error(msg) + logger.warn(msg) + } + + if (name.includes('--')) { + const msg = `name=<${name}> | skill name contains consecutive hyphens` + if (strict) throw new Error(msg) + logger.warn(msg) + } + + if (dirPath != null && basename(dirPath) !== name) { + const msg = `name=<${name}>, directory=<${basename(dirPath)}> | skill name does not match parent directory name` + if (strict) throw new Error(msg) + logger.warn(msg) + } +} + +/** + * Build a Skill instance from parsed frontmatter and body. + */ +function buildSkillFromFrontmatter( + frontmatter: Record, + body: string, + path?: string | undefined +): Skill { + // Parse allowed-tools (space-delimited string or YAML list) + const allowedToolsRaw = (frontmatter['allowed-tools'] ?? frontmatter['allowed_tools']) as + | string + | unknown[] + | undefined + let allowedTools: string[] | undefined + if (typeof allowedToolsRaw === 'string' && allowedToolsRaw.trim()) { + allowedTools = allowedToolsRaw.trim().split(/\s+/) + } else if (Array.isArray(allowedToolsRaw)) { + allowedTools = allowedToolsRaw.filter((item) => item != null).map(String) + } + + // Parse metadata (nested mapping) + const metadataRaw = frontmatter['metadata'] + const metadata: Record = {} + if (typeof metadataRaw === 'object' && metadataRaw !== null && !Array.isArray(metadataRaw)) { + for (const [k, v] of Object.entries(metadataRaw)) { + metadata[String(k)] = v + } + } + + const skillLicense = frontmatter['license'] + const compatibility = frontmatter['compatibility'] + + return new Skill({ + name: frontmatter['name'] as string, + description: frontmatter['description'] as string, + instructions: body, + path, + allowedTools, + metadata, + license: skillLicense != null ? String(skillLicense) : undefined, + compatibility: compatibility != null ? String(compatibility) : undefined, + }) +} + +/** + * Represents an agent skill with metadata and instructions. + * + * A skill encapsulates a set of instructions and metadata that can be + * dynamically loaded by an agent at runtime. Skills support progressive + * disclosure: metadata is shown upfront in the system prompt, and full + * instructions are loaded on demand via a tool. + * + * Skills can be created directly or via convenience static methods: + * + * @example + * ```typescript + * // From a skill directory on disk + * const skill = Skill.fromFile('./skills/my-skill') + * + * // From raw SKILL.md content + * const skill = Skill.fromContent('---\nname: my-skill\n...') + * + * // Load all skills from a parent directory + * const skills = Skill.fromDirectory('./skills/') + * + * // From an HTTPS URL + * const skill = await Skill.fromUrl('https://example.com/SKILL.md') + * ``` + */ +export class Skill { + /** Unique identifier for the skill (1-64 chars, lowercase alphanumeric + hyphens). */ + readonly name: string + + /** Human-readable description of what the skill does. */ + readonly description: string + + /** Full markdown instructions from the SKILL.md body. */ + readonly instructions: string + + /** Filesystem path to the skill directory, if loaded from disk. */ + readonly path: string | undefined + + /** List of tool names the skill is allowed to use. (Experimental: not yet enforced) */ + readonly allowedTools: string[] | undefined + + /** Additional key-value metadata from the SKILL.md frontmatter. */ + readonly metadata: Record + + /** License identifier (e.g., "Apache-2.0"). */ + readonly license: string | undefined + + /** Compatibility information string. */ + readonly compatibility: string | undefined + + constructor(config: SkillConfig) { + this.name = config.name + this.description = config.description + this.instructions = config.instructions ?? '' + this.path = config.path + this.allowedTools = config.allowedTools + this.metadata = config.metadata ?? {} + this.license = config.license + this.compatibility = config.compatibility + } + + /** + * Load a single skill from a directory containing SKILL.md. + * + * Resolves the filesystem path, reads the file content, and delegates + * to {@link fromContent} for parsing. After loading, sets the skill's + * `path` and validates the skill name against the parent directory. + * + * @param skillPath - Path to the skill directory or the SKILL.md file itself. + * @param options - Optional settings. When `strict` is true, throws on any validation issue; otherwise warns and loads anyway. + * @returns A Skill instance populated from the SKILL.md file. + */ + static fromFile(skillPath: string, options?: { strict?: boolean }): Skill { + const resolvedPath = resolve(skillPath) + + let skillMdPath: string + let skillDir: string + + if ( + existsSync(resolvedPath) && + statSync(resolvedPath).isFile() && + basename(resolvedPath).toLowerCase() === 'skill.md' + ) { + skillMdPath = resolvedPath + skillDir = resolve(resolvedPath, '..') + } else if (existsSync(resolvedPath) && statSync(resolvedPath).isDirectory()) { + skillDir = resolvedPath + skillMdPath = findSkillMd(skillDir) + } else { + throw new Error(`path=<${resolvedPath}> | skill path does not exist or is not a valid skill directory`) + } + + logger.debug(`path=<${skillMdPath}> | loading skill`) + + const content = readFileSync(skillMdPath, 'utf-8') + const skill = Skill.fromContent(content, { ...options, path: skillDir }) + + logger.debug(`name=<${skill.name}>, path=<${skill.path}> | skill loaded successfully`) + return skill + } + + /** + * Parse SKILL.md content into a Skill instance. + * + * Creates a Skill from raw SKILL.md content (YAML frontmatter + markdown body) + * without requiring a file on disk. + * + * @example + * ```typescript + * const content = `--- + * name: my-skill + * description: Does something useful + * --- + * # Instructions + * Follow these steps...` + * + * const skill = Skill.fromContent(content) + * ``` + * + * @param content - Raw SKILL.md content with YAML frontmatter and markdown body. + * @param options - Optional settings. When `strict` is true, throws on any validation issue; otherwise warns and loads anyway. + * @returns A Skill instance populated from the parsed content. + */ + static fromContent(content: string, options?: { strict?: boolean; path?: string | undefined }): Skill { + const strict = options?.strict ?? false + const { frontmatter, body } = parseFrontmatter(content) + + const name = frontmatter['name'] + if (typeof name !== 'string' || !name) { + throw new Error("SKILL.md content must have a 'name' field in frontmatter") + } + + const description = frontmatter['description'] + if (typeof description !== 'string' || !description) { + throw new Error("SKILL.md content must have a 'description' field in frontmatter") + } + + validateSkillName(name, options?.path, { strict }) + + return buildSkillFromFrontmatter(frontmatter, body, options?.path) + } + + /** + * Load a skill by fetching its SKILL.md content from an HTTPS URL. + * + * Fetches the raw SKILL.md content over HTTPS and parses it using + * {@link fromContent}. The URL must point directly to the raw file + * content (not an HTML page). + * + * @example + * ```typescript + * const skill = await Skill.fromUrl( + * 'https://raw.githubusercontent.com/org/repo/main/SKILL.md' + * ) + * ``` + * + * @param url - An `https://` URL pointing directly to raw SKILL.md content. + * @param options - Optional settings. When `strict` is true, throws on any validation issue; otherwise warns and loads anyway. + * @returns A Promise resolving to a Skill instance populated from the fetched SKILL.md content. + * @throws If `url` is not an `https://` URL. + * @throws If the SKILL.md content cannot be fetched. + */ + static async fromUrl(url: string, options?: { strict?: boolean }): Promise { + if (!url.startsWith('https://')) { + throw new Error(`url=<${url}> | not a valid HTTPS URL`) + } + + logger.info(`url=<${url}> | fetching skill content`) + + let content: string + try { + const response = await globalThis.fetch(url, { + headers: { 'User-Agent': 'strands-agents-sdk' }, + signal: AbortSignal.timeout(SKILL_HTTPS_FETCH_TIMEOUT_MS), + }) + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + content = await response.text() + } catch (error) { + if (error instanceof Error && error.message.startsWith('HTTP ')) { + throw new Error(`url=<${url}> | ${error.message}`) + } + throw new Error(`url=<${url}> | failed to fetch skill: ${error instanceof Error ? error.message : error}`) + } + + return Skill.fromContent(content, options) + } + + /** + * Load all skills from a parent directory containing skill subdirectories. + * + * Each subdirectory containing a SKILL.md file is treated as a skill. + * Subdirectories without SKILL.md are silently skipped. + * + * @param skillsDir - Path to the parent directory containing skill subdirectories. + * @param options - Optional settings. When `strict` is true, throws on any validation issue; otherwise warns and loads anyway. + * @returns List of Skill instances loaded from the directory. + */ + static fromDirectory(skillsDir: string, options?: { strict?: boolean }): Skill[] { + const resolvedDir = resolve(skillsDir) + + if (!existsSync(resolvedDir) || !statSync(resolvedDir).isDirectory()) { + throw new Error(`path=<${resolvedDir}> | skills directory does not exist`) + } + + const skills: Skill[] = [] + const children = readdirSync(resolvedDir).sort() + + for (const child of children) { + const childPath = join(resolvedDir, child) + if (!existsSync(childPath) || !statSync(childPath).isDirectory()) continue + + try { + findSkillMd(childPath) + } catch { + logger.debug(`path=<${childPath}> | skipping directory without SKILL.md`) + continue + } + + try { + const skill = Skill.fromFile(childPath, options) + skills.push(skill) + } catch (error) { + logger.warn(`path=<${childPath}> | skipping skill due to error: ${error}`) + } + } + + logger.debug(`path=<${resolvedDir}>, count=<${skills.length}> | loaded skills from directory`) + return skills + } +} diff --git a/strands-ts/src/vended-tools/bash/README.md b/strands-ts/src/vended-tools/bash/README.md new file mode 100644 index 0000000000..9bc71225d3 --- /dev/null +++ b/strands-ts/src/vended-tools/bash/README.md @@ -0,0 +1,202 @@ +# Bash Tool + +A robust tool for executing bash shell commands in Node.js environments with persistent session support. + +## ⚠️ Security Warning + +**This tool executes arbitrary bash commands without sandboxing or restrictions.** + +- Only use with trusted input +- Commands execute with the permissions of the Node.js process +- Environment variables are inherited from the parent process +- For production deployments, consider running in a sandboxed environment (containers, VMs, etc.) +- Review all commands before execution +- Never expose this tool to untrusted users without additional security measures + +## Requirements + +**Node.js Only**: This tool requires Node.js and uses the `child_process` module. It will not work in browser environments. + +**Unix/Linux/macOS Only**: This tool uses the `bash` shell and is designed for Unix-like systems. It does not currently support Windows environments. + +## Features + +- **Persistent Sessions**: Commands execute in a persistent bash session, maintaining state (variables, working directory, etc.) across multiple invocations +- **Separate Output Streams**: Captures stdout and stderr independently +- **Configurable Timeouts**: Prevent commands from hanging indefinitely (default: 120 seconds) +- **Session Management**: Restart sessions to clear state when needed +- **Isolated Sessions**: Each agent instance gets its own isolated bash session +- **Working Directory**: Inherits the working directory from `process.cwd()` + +## Installation + +```typescript +import { bash } from '@strands-agents/sdk/vended-tools/bash' +``` + +## Usage + +### With an Agent + +```typescript +import { Agent } from '@strands-agents/sdk' +import { BedrockModel } from '@strands-agents/sdk' +import { bash } from '@strands-agents/sdk/vended-tools/bash' + +const agent = new Agent({ + model: new BedrockModel({ + region: 'us-east-1', + }), + tools: [bash], +}) + +// The agent can now use the bash tool +await agent.invoke('List all files in the current directory') +await agent.invoke('Create a new file called notes.txt with "Hello World"') +``` + +### Session Persistence + +Variables, functions, and working directory persist across commands in the same session: + +```typescript +import { Agent } from '@strands-agents/sdk' +import { BedrockModel } from '@strands-agents/sdk' +import { bash } from '@strands-agents/sdk/vended-tools/bash' + +const model = new BedrockModel({ + region: 'us-east-1', +}) + +const agent = new Agent({ + model, + tools: [bash], +}) + +let res +res = await agent.invoke('run export "MY_VAR=hello"') +console.log(res.lastMessage) + +res = await agent.invoke('run "echo $MY_VAR"') +console.log(res.lastMessage) // Will show "hello" +``` + +### Restart Session + +Clear all session state and start fresh: + +```typescript +import { Agent } from '@strands-agents/sdk' +import { BedrockModel } from '@strands-agents/sdk' +import { bash } from '@strands-agents/sdk/vended-tools/bash' + +const model = new BedrockModel({ + region: 'us-east-1', +}) + +const agent = new Agent({ + model, + tools: [bash], +}) + +// Set a variable +let res = await agent.invoke('run export "TEMP_VAR=exists"') + +// Restart the session +res = await agent.invoke('restart the bash session') + +// Variable is now gone +res = await agent.invoke('run "echo $TEMP_VAR"') +console.log(res.lastMessage) // Variable will be empty/undefined +``` + +## API Reference + +### Input Schema + +#### Execute Mode + +```typescript +interface ExecuteInput { + mode: 'execute' + command: string + timeout?: number // Optional timeout in seconds (default: 120) +} +``` + +#### Restart Mode + +```typescript +interface RestartInput { + mode: 'restart' +} +``` + +### Return Value + +#### Execute Mode + +Returns an object with separate stdout and stderr: + +```typescript +interface BashOutput { + output: string // Standard output (stdout) + error: string // Standard error (stderr) - empty string if no errors +} +``` + +### Error Handling + +The tool throws custom errors for specific failure scenarios: + +- **`BashTimeoutError`**: Thrown when a command exceeds its timeout +- **`BashSessionError`**: Thrown when the bash process encounters an error + +```typescript +import { BashTimeoutError, BashSessionError } from '@strands-agents/sdk/vended-tools/bash' + +try { + await bash.invoke({ mode: 'execute', command: 'sleep 1000', timeout: 1 }, context) +} catch (error) { + if (error instanceof BashTimeoutError) { + console.log('Command timed out') + } else if (error instanceof BashSessionError) { + console.log('Session error occurred') + } +} +``` + +## Implementation Details + +### Session Management + +- Each agent instance gets its own isolated bash session +- Sessions are stored in a WeakMap keyed by agent instance +- Sessions automatically clean up when the agent is garbage collected + +### Working Directory + +- The bash process starts in the directory returned by `process.cwd()` +- You can change directories using `cd` commands +- Directory changes persist within the session + +### Timeout Behavior + +- Default timeout is 120 seconds +- Timeout can be configured per-command +- On timeout, the bash process is killed immediately +- A `BashTimeoutError` is thrown + +## Limitations + +- **No browser support**: Cannot run in browser environments +- **Process permissions**: Commands run with the same permissions as the Node.js process +- **No sandboxing**: Commands execute without isolation or restrictions + +## Best Practices + +1. **Always validate input**: Never pass untrusted input directly to commands +2. **Use timeouts**: Set appropriate timeouts for long-running commands +3. **Check stderr**: Always check the `error` field in the return value +4. **Handle errors**: Wrap tool invocations in try-catch blocks +5. **Quote arguments**: Use proper shell quoting for arguments containing spaces or special characters diff --git a/strands-ts/src/vended-tools/bash/__tests__/bash.test.node.ts b/strands-ts/src/vended-tools/bash/__tests__/bash.test.node.ts new file mode 100644 index 0000000000..0baa8ee4ae --- /dev/null +++ b/strands-ts/src/vended-tools/bash/__tests__/bash.test.node.ts @@ -0,0 +1,475 @@ +import { describe, it, expect, vi, afterEach } from 'vitest' +import { bash } from '../index.js' +import { BashTimeoutError, BashSessionError, type BashOutput } from '../index.js' +import type { ToolContext } from '../../../index.js' +import { StateStore } from '../../../state-store.js' +import { createMockAgent } from '../../../__fixtures__/agent-helpers.js' +import { realpathSync } from 'fs' + +// Skip tests on Windows (bash not available) +describe.skipIf(process.platform === 'win32')('bash tool', () => { + // Helper to create fresh context + const createFreshContext = (): { state: StateStore; context: ToolContext } => { + const agent = createMockAgent() + const context: ToolContext = { + toolUse: { + name: 'bash', + toolUseId: 'test-id', + input: {}, + }, + agent, + invocationState: {}, + interrupt: () => { + throw new Error('interrupt not available in mock context') + }, + } + return { state: agent.appState, context } + } + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe('input validation', () => { + it('accepts valid execute command', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'echo "test"' }, context) + + expect(result).toHaveProperty('output') + expect(result).toHaveProperty('error') + }) + + it('accepts valid restart command', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'restart' }, context) + expect(result).toBe('Bash session restarted') + }) + + it('rejects invalid mode', async () => { + const { context } = createFreshContext() + await expect( + // @ts-expect-error - Testing invalid input + bash.invoke({ mode: 'invalid' }, context) + ).rejects.toThrow() + }) + + it('rejects execute without command', async () => { + const { context } = createFreshContext() + await expect(bash.invoke({ mode: 'execute' }, context)).rejects.toThrow() + }) + + it('accepts valid timeout configuration', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'echo "fast"', timeout: 300 }, context) + + expect(result).toHaveProperty('output') + }) + + it('rejects negative timeout', async () => { + const { context } = createFreshContext() + await expect(bash.invoke({ mode: 'execute', command: 'echo test', timeout: -1 }, context)).rejects.toThrow() + }) + }) + + describe('session lifecycle', () => { + it('creates session on first execute', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'echo "test"' }, context) + + expect(result).toHaveProperty('output') + expect((result as BashOutput).output).toContain('test') + }) + + it('creates new session after restart', async () => { + const { context } = createFreshContext() + + // Set variable + await bash.invoke({ mode: 'execute', command: 'TEST_RESTART="exists"' }, context) + + // Restart + const restartResult = await bash.invoke({ mode: 'restart' }, context) + expect(restartResult).toBe('Bash session restarted') + + // Variable should be gone + const afterRestart = await bash.invoke({ mode: 'execute', command: 'echo $TEST_RESTART' }, context) + + expect((afterRestart as BashOutput).output.trim()).not.toContain('exists') + }) + + it('restarts existing session when restart is called', async () => { + const { context } = createFreshContext() + + // First create a session by executing a command + await bash.invoke({ mode: 'execute', command: 'TEST_VAR="initial"' }, context) + + // Now restart the existing session + const restartResult = await bash.invoke({ mode: 'restart' }, context) + expect(restartResult).toBe('Bash session restarted') + + // Verify the variable is gone after restart + const result = await bash.invoke({ mode: 'execute', command: 'echo "${TEST_VAR:-empty}"' }, context) + expect((result as BashOutput).output.trim()).toBe('empty') + }) + + it('persists environment variables between calls', async () => { + const { context } = createFreshContext() + + await bash.invoke({ mode: 'execute', command: 'MY_VAR="persistent_value"' }, context) + const result = await bash.invoke({ mode: 'execute', command: 'echo $MY_VAR' }, context) + + expect((result as BashOutput).output.trim()).toBe('persistent_value') + }) + + it('persists working directory between calls', async () => { + const { context } = createFreshContext() + + await bash.invoke({ mode: 'execute', command: 'cd /tmp' }, context) + const result = await bash.invoke({ mode: 'execute', command: 'pwd' }, context) + + expect(realpathSync((result as BashOutput).output.trim())).toBe(realpathSync('/tmp')) + }) + + it('provides isolated sessions for different agents', async () => { + const { context: context1 } = createFreshContext() + const { context: context2 } = createFreshContext() + + // Set variable in first agent + await bash.invoke({ mode: 'execute', command: 'AGENT_VAR="agent1"' }, context1) + + // Check it's not in second agent + const result = await bash.invoke({ mode: 'execute', command: 'echo $AGENT_VAR' }, context2) + + expect((result as BashOutput).output.trim()).not.toContain('agent1') + }) + + it('handles session restart with no existing session gracefully', async () => { + const { context } = createFreshContext() + + // Restart when no session exists + const result = await bash.invoke({ mode: 'restart' }, context) + expect(result).toBe('Bash session restarted') + + // Should still be able to execute commands + const execResult = await bash.invoke({ mode: 'execute', command: 'echo "works"' }, context) + expect((execResult as BashOutput).output.trim()).toBe('works') + }) + }) + + describe('command execution', () => { + it('executes command and returns output', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'echo "Hello World"' }, context) + + expect((result as BashOutput).output).toContain('Hello World') + expect((result as BashOutput).error).toBe('') + }) + + it('returns empty stderr on success', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'echo "success"' }, context) + + expect((result as BashOutput).error).toBe('') + }) + + it('captures stderr on command error', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'nonexistent_command_xyz' }, context) + + expect((result as BashOutput).error).toContain('not found') + }) + }) + + describe('timeout handling', () => { + it('completes command before timeout', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'echo "fast"', timeout: 5 }, context) + + expect((result as BashOutput).output).toContain('fast') + }) + + it('throws BashTimeoutError when command times out', async () => { + const { context } = createFreshContext() + + await expect(bash.invoke({ mode: 'execute', command: 'sleep 10', timeout: 0.1 }, context)).rejects.toThrow( + BashTimeoutError + ) + }) + + it('uses default timeout of 120 seconds', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'echo "test"' }, context) + + expect(result).toHaveProperty('output') + }) + + it('respects custom timeout for new session', async () => { + const { context } = createFreshContext() + + // Create session with custom timeout + const result = await bash.invoke({ mode: 'execute', command: 'echo "custom"', timeout: 10 }, context) + + expect((result as BashOutput).output).toContain('custom') + }) + + it('handles timeout during command with large output', async () => { + const { context } = createFreshContext() + + // Command that generates output continuously + await expect( + bash.invoke({ mode: 'execute', command: 'while true; do echo "spam"; done', timeout: 0.1 }, context) + ).rejects.toThrow(BashTimeoutError) + }) + }) + + describe('error handling', () => { + it('requires context for bash operations', async () => { + await expect(bash.invoke({ mode: 'execute', command: 'echo "test"' })).rejects.toThrow('Tool context is required') + }) + + it('validates command is required for execute mode', async () => { + const { context } = createFreshContext() + + await expect(bash.invoke({ mode: 'execute' }, context)).rejects.toThrow( + 'command is required when mode is "execute"' + ) + }) + + it('validates command is required with undefined command', async () => { + const { context } = createFreshContext() + + await expect(bash.invoke({ mode: 'execute', command: undefined }, context)).rejects.toThrow( + 'command is required when mode is "execute"' + ) + }) + + it('validates command is required with empty string', async () => { + const { context } = createFreshContext() + + await expect(bash.invoke({ mode: 'execute', command: '' }, context)).rejects.toThrow( + 'command is required when mode is "execute"' + ) + }) + + it('handles command execution in a session without proper initialization', async () => { + const { context } = createFreshContext() + + // Create a session first + await bash.invoke({ mode: 'execute', command: 'echo "init"' }, context) + + // Then restart to clear it + await bash.invoke({ mode: 'restart' }, context) + + // Try to execute another command - should work as it creates a new session + const result = await bash.invoke({ mode: 'execute', command: 'echo "after restart"' }, context) + + expect((result as BashOutput).output).toContain('after restart') + }) + + it('creates new session when none exists', async () => { + const { context } = createFreshContext() + + // First command should create a new session + const result = await bash.invoke({ mode: 'execute', command: 'echo "first"' }, context) + + expect((result as BashOutput).output).toContain('first') + }) + + it('handles restart when no session exists', async () => { + const { context } = createFreshContext() + + // Restart without existing session should not throw + const result = await bash.invoke({ mode: 'restart' }, context) + expect(result).toBe('Bash session restarted') + }) + + it('properly cleans up session on restart', async () => { + const { context } = createFreshContext() + + // Create session with variable + await bash.invoke({ mode: 'execute', command: 'CLEANUP_TEST="should_be_gone"' }, context) + + // Restart should clear the session + await bash.invoke({ mode: 'restart' }, context) + + // Variable should not exist in new session + const result = await bash.invoke({ mode: 'execute', command: 'echo "${CLEANUP_TEST:-empty}"' }, context) + + expect((result as BashOutput).output.trim()).toBe('empty') + }) + + it('handles multiple restarts in sequence', async () => { + const { context } = createFreshContext() + + // Restart without existing session + const result1 = await bash.invoke({ mode: 'restart' }, context) + expect(result1).toBe('Bash session restarted') + + // Restart again + const result2 = await bash.invoke({ mode: 'restart' }, context) + expect(result2).toBe('Bash session restarted') + + // Should still be able to execute + const execResult = await bash.invoke({ mode: 'execute', command: 'echo "still works"' }, context) + expect((execResult as BashOutput).output).toContain('still works') + }) + + it('handles command with empty output gracefully', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'true' }, context) + + expect((result as BashOutput).output).toBe('') + expect((result as BashOutput).error).toBe('') + }) + + it('handles command with only whitespace output', async () => { + const { context } = createFreshContext() + const result = await bash.invoke({ mode: 'execute', command: 'echo " "' }, context) + + expect((result as BashOutput).output.trim()).toBe('') + }) + + it('handles very long command output', async () => { + const { context } = createFreshContext() + // Generate a long string + const result = await bash.invoke( + { + mode: 'execute', + command: 'for i in {1..100}; do echo "Line $i of output"; done', + }, + context + ) + + expect((result as BashOutput).output).toContain('Line 1 of output') + expect((result as BashOutput).output).toContain('Line 100 of output') + }) + + it('creates session with default timeout when not specified', async () => { + const { context } = createFreshContext() + + // Execute without timeout parameter + const result = await bash.invoke({ mode: 'execute', command: 'echo "default"' }, context) + + expect((result as BashOutput).output).toContain('default') + }) + }) + + describe('working directory', () => { + it('starts in process.cwd()', async () => { + const { context } = createFreshContext() + const expectedCwd = realpathSync(process.cwd()) + + const result = await bash.invoke({ mode: 'execute', command: 'pwd' }, context) + + expect(realpathSync((result as BashOutput).output)).toContain(expectedCwd) + }) + }) + + describe('tool properties', () => { + it('has correct tool name', () => { + expect(bash.name).toBe('bash') + }) + + it('has description', () => { + expect(bash.description).toBeDefined() + expect(bash.description.length).toBeGreaterThan(0) + }) + + it('has toolSpec', () => { + expect(bash.toolSpec).toBeDefined() + expect(bash.toolSpec.name).toBe('bash') + }) + }) + + describe('error classes', () => { + it('BashTimeoutError has correct properties', () => { + const error = new BashTimeoutError('timeout message') + expect(error.name).toBe('BashTimeoutError') + expect(error.message).toBe('timeout message') + expect(error instanceof Error).toBe(true) + }) + + it('BashSessionError has correct properties', () => { + const error = new BashSessionError('session error message') + expect(error.name).toBe('BashSessionError') + expect(error.message).toBe('session error message') + expect(error instanceof Error).toBe(true) + }) + }) + + describe('module exports', () => { + it('exports bash tool from index', () => { + expect(bash).toBeDefined() + expect(bash.name).toBe('bash') + }) + + it('exports error classes from index', () => { + expect(BashTimeoutError).toBeDefined() + expect(BashSessionError).toBeDefined() + }) + }) + + describe('bash session edge cases', () => { + it('handles process close during command execution', async () => { + const { context } = createFreshContext() + + // Use a command that will make the bash process exit - this should throw an error + await expect(bash.invoke({ mode: 'execute', command: 'exit 0' }, context)).rejects.toThrow(BashSessionError) + + // Next command should work with a new session + const newResult = await bash.invoke({ mode: 'execute', command: 'echo "new session"' }, context) + expect((newResult as BashOutput).output).toContain('new session') + }) + }) + + describe('process cleanup', () => { + it('cleans up on beforeExit event', async () => { + const { context } = createFreshContext() + + // Create a session + await bash.invoke({ mode: 'execute', command: 'echo "test"' }, context) + + // Simulate beforeExit event + process.emit('beforeExit', 0) + + // Session should be cleaned up, next command creates new session + const result = await bash.invoke({ mode: 'execute', command: 'echo "after exit"' }, context) + expect((result as BashOutput).output).toContain('after exit') + }) + + it('cleans up on exit event', async () => { + const { context } = createFreshContext() + + // Create a session + await bash.invoke({ mode: 'execute', command: 'echo "test"' }, context) + + // Simulate exit event + process.emit('exit', 0) + + // Session should be cleaned up + const result = await bash.invoke({ mode: 'execute', command: 'echo "after exit"' }, context) + expect((result as BashOutput).output).toContain('after exit') + }) + + it('cleans up on SIGINT', async () => { + const { context } = createFreshContext() + + // Mock process.exit to prevent actual exit + const exitMock = vi.spyOn(process, 'exit').mockImplementation(() => { + throw new Error('process.exit called') + }) + + // Create a session + await bash.invoke({ mode: 'execute', command: 'echo "test"' }, context) + + // Simulate SIGINT + try { + process.emit('SIGINT') + } catch { + // Expected to throw due to our mock + } + + expect(exitMock).toHaveBeenCalledWith(0) + exitMock.mockRestore() + }) + }) +}) diff --git a/strands-ts/src/vended-tools/bash/bash.ts b/strands-ts/src/vended-tools/bash/bash.ts new file mode 100644 index 0000000000..ccec3c1a3f --- /dev/null +++ b/strands-ts/src/vended-tools/bash/bash.ts @@ -0,0 +1,296 @@ +import { tool } from '../../tools/tool-factory.js' +import { z } from 'zod' +import { spawn, type ChildProcess } from 'child_process' +import { Buffer } from 'buffer' +import type { BashOutput } from './types.js' +import { BashTimeoutError, BashSessionError } from './types.js' + +/** + * Zod schema for bash input validation. + * + * Note: Uses a single object schema instead of discriminated union for AWS Bedrock compatibility. + */ +const bashInputSchema = z.object({ + mode: z + .enum(['execute', 'restart']) + .describe('Operation mode: "execute" to run a command, "restart" to restart the session'), + command: z.string().optional().describe('The bash command to execute (required when mode is "execute")'), + timeout: z.number().positive().optional().describe('Timeout in seconds (default: 120, applies only to execute mode)'), +}) + +/** + * Internal class for managing a bash session. + */ +class BashSession { + private _process: ChildProcess | null = null + private _started = false + private readonly _timeout: number + private readonly _sentinel: string + + constructor(timeout = 120) { + this._timeout = timeout + this._sentinel = `__BASH_DONE_${Date.now()}_${Math.random().toString(36).slice(2)}__` + } + + /** + * Starts the bash process if not already started. + */ + start(): void { + if (this._started) { + return + } + + try { + this._process = spawn('bash', [], { + cwd: process.cwd(), + env: { ...process.env, PS1: '', PS2: '' }, + }) + + if (!this._process.stdin || !this._process.stdout || !this._process.stderr) { + throw new BashSessionError('Failed to create bash process streams') + } + + this._started = true + activeSessions.add(this) + + // Handle unexpected process exits + this._process.on('close', () => { + this._process = null + this._started = false + }) + } catch (err) { + throw new BashSessionError(`Failed to start bash session: ${(err as Error).message}`) + } + } + + /** + * Stops the bash process. + */ + stop(): void { + if (this._process) { + this._process.kill() + this._process = null + this._started = false + } + activeSessions.delete(this) + } + + /** + * Runs a command in the bash session. + */ + async run(command: string, timeout?: number): Promise { + this.start() + + if (!this._process || !this._process.stdin || !this._process.stdout || !this._process.stderr) { + throw new BashSessionError('Bash session not properly initialized') + } + + const effectiveTimeout = timeout ?? this._timeout + let stdoutData = '' + let stderrData = '' + let timeoutHandle: ReturnType | null = null + let isTimedOut = false + + return new Promise((resolve, reject) => { + const stdout = this._process!.stdout! + const stderr = this._process!.stderr! + const stdin = this._process!.stdin! + + // Handlers for stdout + const onStdoutData = (chunk: unknown): void => { + const data = Buffer.from(chunk as Parameters[0]).toString('utf-8') + stdoutData += data + + // Check for sentinel + if (stdoutData.includes(this._sentinel)) { + cleanup() + + // Remove sentinel from output + const output = stdoutData.replace(this._sentinel, '').trim() + const error = stderrData.trim() + + resolve({ output, error }) + } + } + + // Handlers for stderr + const onStderrData = (chunk: unknown): void => { + stderrData += Buffer.from(chunk as Parameters[0]).toString('utf-8') + } + + // Handler for process close + const onClose = (code: number | null): void => { + if (!isTimedOut) { + cleanup() + reject(new BashSessionError(`Bash process exited unexpectedly with code ${code ?? 'unknown'}`)) + } + } + + // Handler for process errors + const onError = (err: Error): void => { + cleanup() + this.stop() + reject(new BashSessionError(`Bash process error: ${err.message}`)) + } + + // Cleanup function - removes per-command listeners and timeout. + // Does NOT stop the process, preserving session state between calls. + const cleanup = (): void => { + if (timeoutHandle !== null) { + clearTimeout(timeoutHandle) + timeoutHandle = null + } + stdout.off('data', onStdoutData) + stderr.off('data', onStderrData) + // Check if process still exists before removing listeners + if (this._process) { + this._process.off('close', onClose) + this._process.off('error', onError) + } + } + + // Set up timeout + timeoutHandle = setTimeout(() => { + isTimedOut = true + cleanup() + this.stop() + reject(new BashTimeoutError(`Command timed out after ${effectiveTimeout} seconds`)) + }, effectiveTimeout * 1000) + + // Attach listeners + stdout.on('data', onStdoutData) + stderr.on('data', onStderrData) + this._process!.on('close', onClose) + this._process!.on('error', onError) + + // Send command with sentinel + try { + stdin.write(`${command}\necho "${this._sentinel}"\n`) + } catch (err) { + cleanup() + this.stop() + reject(new BashSessionError(`Failed to write command: ${(err as Error).message}`)) + } + }) + } +} + +/** + * WeakMap to store bash sessions per agent instance. + */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const sessions = new WeakMap() + +/** + * Track all active sessions for cleanup on process exit. + */ +const activeSessions = new Set() + +/** + * Clean up bash sessions when their associated agent is garbage collected. + */ +const sessionFinalizer = new FinalizationRegistry((session) => { + session.stop() +}) + +/** + * Clean up all active bash sessions. + */ +function cleanupAllSessions(): void { + for (const session of activeSessions) { + session.stop() + } + activeSessions.clear() +} + +// Register cleanup handlers for process exit +process.on('beforeExit', () => { + // beforeExit fires when event loop is empty but process is still alive + // This is our chance to clean up bash processes before they prevent exit + cleanupAllSessions() +}) +process.on('exit', cleanupAllSessions) +process.on('SIGINT', () => { + cleanupAllSessions() + /* c8 ignore next */ + process.exit(0) +}) +/* c8 ignore start */ +process.on('SIGTERM', () => { + cleanupAllSessions() + process.exit(0) +}) +/* c8 ignore stop */ + +/** + * Bash tool for executing shell commands in Node.js environments. + * + * This tool provides a persistent bash session that can execute commands and maintain state + * across multiple invocations within the same agent session. + * + * **Security Warning**: This tool executes arbitrary bash commands without sandboxing. + * Only use with trusted input and consider sandboxing for production deployments. + * + * **Node.js Only**: This tool requires Node.js and the `child_process` module. + * It will not work in browser environments. + * + * @example + * ```typescript + * // With agent + * const agent = new Agent({ tools: [bash] }) + * await agent.invoke('List files in the current directory') + * + * // Direct usage + * const result = await bash.invoke( + * { mode: 'execute', command: 'echo "Hello"' }, + * context + * ) + * console.log(result.output) // "Hello" + * ``` + */ +export const bash = tool({ + name: 'bash', + description: + 'Executes bash shell commands in a persistent session. Supports execute and restart modes. ' + + 'Commands persist state (variables, directory) within the session. Node.js only.', + inputSchema: bashInputSchema, + callback: async (input, context) => { + if (!context) { + throw new Error('Tool context is required for bash operations') + } + + const agent = context.agent + + // Validate execute mode has command + if (input.mode === 'execute' && !input.command) { + throw new Error('command is required when mode is "execute"') + } + + // Handle restart mode + if (input.mode === 'restart') { + const existingSession = sessions.get(agent) + if (existingSession) { + existingSession.stop() + sessions.delete(agent) + } + // Create new session (will be added to activeSessions when started) + const newSession = new BashSession(120) + sessions.set(agent, newSession) + sessionFinalizer.register(agent, newSession) + return 'Bash session restarted' + } + + // Handle execute mode + // Get or create session + let session = sessions.get(agent) + if (!session) { + session = new BashSession(input.timeout ?? 120) + sessions.set(agent, session) + sessionFinalizer.register(agent, session) + } + + // Execute command + const result = await session.run(input.command!, input.timeout) + return result + }, +}) diff --git a/strands-ts/src/vended-tools/bash/index.ts b/strands-ts/src/vended-tools/bash/index.ts new file mode 100644 index 0000000000..87c6794775 --- /dev/null +++ b/strands-ts/src/vended-tools/bash/index.ts @@ -0,0 +1,7 @@ +/** + * Bash tool for executing shell commands in Node.js environments. + */ + +export { bash } from './bash.js' +export type { BashInput, BashOutput, ExecuteInput, RestartInput } from './types.js' +export { BashTimeoutError, BashSessionError } from './types.js' diff --git a/strands-ts/src/vended-tools/bash/types.ts b/strands-ts/src/vended-tools/bash/types.ts new file mode 100644 index 0000000000..6d72ef1e92 --- /dev/null +++ b/strands-ts/src/vended-tools/bash/types.ts @@ -0,0 +1,80 @@ +/** + * Type definitions for the bash tool. + */ + +/** + * Input parameters for execute operation. + */ +export interface ExecuteInput { + /** + * Operation mode, must be 'execute'. + */ + mode: 'execute' + + /** + * The bash command to execute. + */ + command: string + + /** + * Timeout in seconds for the command execution. + * Defaults to 120 seconds. + */ + timeout?: number +} + +/** + * Input parameters for restart operation. + */ +export interface RestartInput { + /** + * Operation mode, must be 'restart'. + */ + mode: 'restart' +} + +/** + * Union type of all valid bash tool inputs. + */ +export type BashInput = ExecuteInput | RestartInput + +/** + * Output format for bash command execution. + */ +export interface BashOutput { + /** + * Standard output from the command. + */ + output: string + + /** + * Standard error from the command. + * Empty string if no errors occurred. + */ + error: string + + /** + * Allow indexing with string keys for JSONValue compatibility. + */ + [key: string]: string +} + +/** + * Error thrown when a bash command exceeds its timeout. + */ +export class BashTimeoutError extends Error { + constructor(message: string) { + super(message) + this.name = 'BashTimeoutError' + } +} + +/** + * Error thrown when a bash session encounters an error. + */ +export class BashSessionError extends Error { + constructor(message: string) { + super(message) + this.name = 'BashSessionError' + } +} diff --git a/strands-ts/src/vended-tools/file-editor/README.md b/strands-ts/src/vended-tools/file-editor/README.md new file mode 100644 index 0000000000..52b4f8775b --- /dev/null +++ b/strands-ts/src/vended-tools/file-editor/README.md @@ -0,0 +1,97 @@ +# File Editor Tool + +A filesystem editor tool for viewing, creating, and editing files programmatically. Provides string replacement, line insertion, and directory viewing with security validation. + +## Features + +- **View files** with line numbers and optional line range support +- **Create files** with initial content +- **String-based find and replace** with uniqueness validation +- **Line-based text insertion** at any position +- **Directory viewing** up to 2 levels deep (configurable) +- **Configurable file size limits** (default 1MB) + +## Installation + +```typescript +import { fileEditor } from '@strands-agents/sdk/vended-tools/file-editor' +import { Agent, BedrockModel } from '@strands-agents/sdk' + +const agent = new Agent({ + model: new BedrockModel({ region: 'us-east-1' }), + tools: [fileEditor], +}) + +await agent.invoke('Create a file /tmp/notes.txt with "# My Notes"') +``` + +## Commands + +### `view` + +View file contents with line numbers or list directory contents (up to 2 levels deep). + +**Parameters:** + +- `path` (string, required): Absolute path to file or directory +- `view_range` (optional): `[start_line, end_line]` (1-indexed, end can be -1 for EOF) + +### `create` + +Create a new file with content. Creates parent directories if needed. + +**Parameters:** + +- `path` (string, required): Absolute path for new file +- `file_text` (string, required): Initial content + +### `str_replace` + +Replace an exact string match in a file. The string must appear exactly once. + +**Parameters:** + +- `path` (string, required): Absolute path to file +- `old_str` (string, required): Exact string to find +- `new_str` (string, optional): Replacement string + +### `insert` + +Insert text at a specific line number (0-indexed). + +**Parameters:** + +- `path` (string, required): Absolute path to file +- `insert_line` (number, required): Line number for insertion (0 = beginning) +- `new_str` (string, required): Text to insert + +## Example Usage + +```typescript +import { fileEditor } from '@strands-agents/sdk/vended-tools/file-editor' +import { Agent, BedrockModel } from '@strands-agents/sdk' + +const agent = new Agent({ + model: new BedrockModel({ region: 'us-east-1' }), + tools: [fileEditor], +}) + +// Agent can use natural language +await agent.invoke('Create /tmp/config.json with {"debug": false}') +await agent.invoke('Replace "debug": false with "debug": true in /tmp/config.json') +await agent.invoke('View lines 1-20 of /tmp/config.json') +``` + +## Security + +- Requires absolute paths (must start with `/`) +- Blocks directory traversal attempts (`..`) +- File size limits (default 1MB) +- Clear error messages + +## Limitations + +- Node.js only (uses filesystem APIs) +- Text files only (UTF-8 encoded) +- Exact string matching (no regex) +- History is session-scoped diff --git a/strands-ts/src/vended-tools/file-editor/__tests__/file-editor.test.node.ts b/strands-ts/src/vended-tools/file-editor/__tests__/file-editor.test.node.ts new file mode 100644 index 0000000000..4cc85700e3 --- /dev/null +++ b/strands-ts/src/vended-tools/file-editor/__tests__/file-editor.test.node.ts @@ -0,0 +1,505 @@ +import { describe, it, expect, beforeEach, afterEach } from 'vitest' +import { fileEditor } from '../file-editor.js' +import type { ToolContext } from '../../../index.js' +import { StateStore } from '../../../state-store.js' +import { createMockAgent } from '../../../__fixtures__/agent-helpers.js' +import { promises as fs } from 'fs' +import * as path from 'path' +import { tmpdir } from 'os' + +describe('fileEditor tool', () => { + let testDir: string + let context: ToolContext + + // Helper to create fresh state and context for each test + const createFreshContext = (): { state: StateStore; context: ToolContext } => { + const agent = createMockAgent() + const toolContext: ToolContext = { + toolUse: { + name: 'fileEditor', + toolUseId: 'test-id', + input: {}, + }, + agent, + invocationState: {}, + interrupt: () => { + throw new Error('interrupt not available in mock context') + }, + } + return { state: agent.appState, context: toolContext } + } + + // Helper to create a test file + const createTestFile = async (filename: string, content: string): Promise => { + const filePath = path.join(testDir, filename) + const dir = path.dirname(filePath) + await fs.mkdir(dir, { recursive: true }) + await fs.writeFile(filePath, content, 'utf-8') + return filePath + } + + // Helper to create a test directory with files + const createTestDirectory = async (dirName: string, files: Record): Promise => { + const dirPath = path.join(testDir, dirName) + await fs.mkdir(dirPath, { recursive: true }) + for (const [filename, content] of Object.entries(files)) { + const filePath = path.join(dirPath, filename) + const fileDir = path.dirname(filePath) + await fs.mkdir(fileDir, { recursive: true }) + await fs.writeFile(filePath, content, 'utf-8') + } + return dirPath + } + + beforeEach(async () => { + // Create a temporary test directory + testDir = path.join(tmpdir(), `file-editor-test-${Date.now()}-${Math.random().toString(36).slice(2)}`) + await fs.mkdir(testDir, { recursive: true }) + + // Create fresh state and context + const fresh = createFreshContext() + context = fresh.context + }) + + afterEach(async () => { + // Clean up test directory + try { + await fs.rm(testDir, { recursive: true, force: true }) + } catch { + // Ignore cleanup errors + } + }) + + describe('view command', () => { + describe('when viewing entire file', () => { + it('returns file content with line numbers', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3') + const result = await fileEditor.invoke({ command: 'view', path: filePath }, context) + expect(result).toContain("Here's the result of running `cat -n`") + expect(result).toContain(' 1 Line 1') + expect(result).toContain(' 2 Line 2') + expect(result).toContain(' 3 Line 3') + }) + + it('handles empty file', async () => { + const filePath = await createTestFile('empty.txt', '') + const result = await fileEditor.invoke({ command: 'view', path: filePath }, context) + expect(result).toContain("Here's the result of running `cat -n`") + }) + + it('handles single line file', async () => { + const filePath = await createTestFile('single.txt', 'Only one line') + const result = await fileEditor.invoke({ command: 'view', path: filePath }, context) + expect(result).toContain(' 1 Only one line') + }) + }) + + describe('when viewing with line range', () => { + it('returns specified lines with line numbers', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5') + const result = await fileEditor.invoke({ command: 'view', path: filePath, view_range: [2, 4] }, context) + expect(result).toContain(' 2 Line 2') + expect(result).toContain(' 3 Line 3') + expect(result).toContain(' 4 Line 4') + expect(result).not.toContain(' 1 ') + expect(result).not.toContain(' 5 ') + }) + + it('handles negative end index (-1 means to end)', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5') + const result = await fileEditor.invoke({ command: 'view', path: filePath, view_range: [3, -1] }, context) + expect(result).toContain(' 3 Line 3') + expect(result).toContain(' 4 Line 4') + expect(result).toContain(' 5 Line 5') + expect(result).not.toContain(' 1 ') + expect(result).not.toContain(' 2 ') + }) + + it('handles single line range', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3') + const result = await fileEditor.invoke({ command: 'view', path: filePath, view_range: [2, 2] }, context) + expect(result).toContain(' 2 Line 2') + expect(result).not.toContain(' 1 ') + expect(result).not.toContain(' 3 ') + }) + }) + + describe('when viewing directory', () => { + it('lists files up to 2 levels deep', async () => { + const dirPath = await createTestDirectory('testdir', { + 'file1.txt': 'content', + 'file2.txt': 'content', + 'subdir/file3.txt': 'content', + 'subdir/nested/file4.txt': 'content', + }) + const result = await fileEditor.invoke({ command: 'view', path: dirPath }, context) + expect(result).toContain('file1.txt') + expect(result).toContain('file2.txt') + expect(result).toContain('subdir') + expect(result).toContain('file3.txt') + expect(result).toContain('file4.txt') + }) + + it('excludes hidden files', async () => { + const dirPath = await createTestDirectory('testdir', { + 'visible.txt': 'content', + '.hidden.txt': 'content', + 'subdir/.hidden-dir/file.txt': 'content', + }) + const result = await fileEditor.invoke({ command: 'view', path: dirPath }, context) + expect(result).toContain('visible.txt') + expect(result).not.toContain('.hidden') + }) + }) + + describe('error cases', () => { + it('throws when file not found', async () => { + const nonExistentPath = path.join(testDir, 'nonexistent.txt') + await expect(fileEditor.invoke({ command: 'view', path: nonExistentPath }, context)).rejects.toThrow( + 'does not exist' + ) + }) + + it('throws when path is not absolute', async () => { + await expect(fileEditor.invoke({ command: 'view', path: 'relative/path.txt' }, context)).rejects.toThrow( + 'not an absolute path' + ) + }) + + it('throws when view_range has invalid start line', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3') + await expect( + fileEditor.invoke({ command: 'view', path: filePath, view_range: [0, 2] }, context) + ).rejects.toThrow('view_range') + }) + + it('throws when view_range end is beyond file length', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3') + await expect( + fileEditor.invoke({ command: 'view', path: filePath, view_range: [1, 10] }, context) + ).rejects.toThrow('view_range') + }) + + it('throws when view_range end is before start', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3') + await expect( + fileEditor.invoke({ command: 'view', path: filePath, view_range: [3, 1] }, context) + ).rejects.toThrow('view_range') + }) + + it('throws when view_range is provided for directory', async () => { + const dirPath = await createTestDirectory('testdir', { 'file.txt': 'content' }) + await expect( + fileEditor.invoke({ command: 'view', path: dirPath, view_range: [1, 2] }, context) + ).rejects.toThrow('not allowed when') + }) + }) + }) + + describe('create command', () => { + it('creates new file with content', async () => { + const filePath = path.join(testDir, 'new-file.txt') + const content = 'Hello World\nLine 2' + const result = await fileEditor.invoke({ command: 'create', path: filePath, file_text: content }, context) + expect(result).toContain('File created successfully') + expect(result).toContain(filePath) + + // Verify file was created + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe(content) + }) + + it('creates file in non-existent directory', async () => { + const filePath = path.join(testDir, 'newdir', 'subdir', 'new-file.txt') + const content = 'Content' + const result = await fileEditor.invoke({ command: 'create', path: filePath, file_text: content }, context) + expect(result).toContain('File created successfully') + + // Verify file and directories were created + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe(content) + }) + + it('creates empty file', async () => { + const filePath = path.join(testDir, 'empty.txt') + const result = await fileEditor.invoke({ command: 'create', path: filePath, file_text: '' }, context) + expect(result).toContain('File created successfully') + + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('') + }) + + describe('error cases', () => { + it('throws when file already exists', async () => { + const filePath = await createTestFile('existing.txt', 'content') + await expect( + fileEditor.invoke({ command: 'create', path: filePath, file_text: 'new content' }, context) + ).rejects.toThrow('already exists') + }) + + it('throws when path is not absolute', async () => { + await expect( + fileEditor.invoke({ command: 'create', path: 'relative/path.txt', file_text: 'content' }, context) + ).rejects.toThrow('not an absolute path') + }) + + it('throws when path contains traversal', async () => { + const filePath = '..outside.txt' + await expect( + fileEditor.invoke({ command: 'create', path: filePath, file_text: 'content' }, context) + ).rejects.toThrow() + }) + + it('throws when trying to create in directory as path', async () => { + const dirPath = await createTestDirectory('testdir', {}) + await expect( + fileEditor.invoke({ command: 'create', path: dirPath, file_text: 'content' }, context) + ).rejects.toThrow('already exists') + }) + }) + }) + + describe('str_replace command', () => { + it('replaces unique string occurrence', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2 OLD\nLine 3\nLine 4') + const result = await fileEditor.invoke( + { command: 'str_replace', path: filePath, old_str: 'OLD', new_str: 'NEW' }, + context + ) + expect(result).toContain('The file') + expect(result).toContain('has been edited') + expect(result).toContain('NEW') + + // Verify file was updated + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('Line 1\nLine 2 NEW\nLine 3\nLine 4') + }) + + it('shows snippet with 4 lines before and after change', async () => { + const content = 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5 OLD\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10' + const filePath = await createTestFile('test.txt', content) + const result = await fileEditor.invoke( + { command: 'str_replace', path: filePath, old_str: 'OLD', new_str: 'NEW' }, + context + ) + // Should show lines 1-9 (4 before + line 5 + 4 after) + expect(result).toContain('Line 1') + expect(result).toContain('Line 9') + expect(result).not.toContain('Line 10') + }) + + it('handles empty new_str (deletion)', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2 DELETE_ME\nLine 3') + const result = await fileEditor.invoke( + { command: 'str_replace', path: filePath, old_str: ' DELETE_ME', new_str: '' }, + context + ) + expect(result).toContain('has been edited') + + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('Line 1\nLine 2\nLine 3') + }) + + it('handles multi-line old_str', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nOLD LINE 1\nOLD LINE 2\nLine 4') + const result = await fileEditor.invoke( + { command: 'str_replace', path: filePath, old_str: 'OLD LINE 1\nOLD LINE 2', new_str: 'NEW LINE' }, + context + ) + expect(result).toContain('has been edited') + + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('Line 1\nNEW LINE\nLine 4') + }) + + it('preserves dollar sign patterns in new_str literally', async () => { + const filePath = await createTestFile('test.txt', 'const value = getPrice()') + await fileEditor.invoke( + { command: 'str_replace', path: filePath, old_str: 'getPrice()', new_str: '$& is not $1 or $$' }, + context + ) + + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('const value = $& is not $1 or $$') + }) + + describe('error cases', () => { + it('throws when old_str not found', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3') + await expect( + fileEditor.invoke({ command: 'str_replace', path: filePath, old_str: 'NOTFOUND', new_str: 'NEW' }, context) + ).rejects.toThrow('did not appear') + }) + + it('throws when multiple occurrences of old_str', async () => { + const filePath = await createTestFile('test.txt', 'DUP Line 1\nLine 2\nDUP Line 3') + await expect( + fileEditor.invoke({ command: 'str_replace', path: filePath, old_str: 'DUP', new_str: 'NEW' }, context) + ).rejects.toThrow('Multiple occurrences') + }) + + it('throws when file not found', async () => { + const nonExistentPath = path.join(testDir, 'nonexistent.txt') + await expect( + fileEditor.invoke({ command: 'str_replace', path: nonExistentPath, old_str: 'OLD', new_str: 'NEW' }, context) + ).rejects.toThrow('does not exist') + }) + + it('throws when path is directory', async () => { + const dirPath = await createTestDirectory('testdir', {}) + await expect( + fileEditor.invoke({ command: 'str_replace', path: dirPath, old_str: 'OLD', new_str: 'NEW' }, context) + ).rejects.toThrow('directory') + }) + }) + }) + + describe('insert command', () => { + it('inserts at beginning (line 0)', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3') + const result = await fileEditor.invoke( + { command: 'insert', path: filePath, insert_line: 0, new_str: 'NEW LINE' }, + context + ) + expect(result).toContain('has been edited') + + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('NEW LINE\nLine 1\nLine 2\nLine 3') + }) + + it('inserts in middle', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3') + const result = await fileEditor.invoke( + { command: 'insert', path: filePath, insert_line: 2, new_str: 'NEW LINE' }, + context + ) + expect(result).toContain('has been edited') + + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('Line 1\nLine 2\nNEW LINE\nLine 3') + }) + + it('inserts at end', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2\nLine 3') + const result = await fileEditor.invoke( + { command: 'insert', path: filePath, insert_line: 3, new_str: 'NEW LINE' }, + context + ) + expect(result).toContain('has been edited') + + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('Line 1\nLine 2\nLine 3\nNEW LINE') + }) + + it('shows snippet with 4 lines before and after insertion', async () => { + const content = 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9' + const filePath = await createTestFile('test.txt', content) + const result = await fileEditor.invoke( + { command: 'insert', path: filePath, insert_line: 5, new_str: 'INSERTED' }, + context + ) + // Inserting at line 5 (0-indexed) means after Line 5 + // Snippet shows 4 lines before (lines 2-5) + inserted + 4 lines after (lines 6-9) + expect(result).toContain('Line 2') + expect(result).toContain('Line 9') + expect(result).toContain('INSERTED') + }) + + it('handles multi-line insertion', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2') + const result = await fileEditor.invoke( + { command: 'insert', path: filePath, insert_line: 1, new_str: 'NEW 1\nNEW 2\nNEW 3' }, + context + ) + expect(result).toContain('has been edited') + + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('Line 1\nNEW 1\nNEW 2\nNEW 3\nLine 2') + }) + + it('handles insertion in empty file', async () => { + const filePath = await createTestFile('empty.txt', '') + const result = await fileEditor.invoke( + { command: 'insert', path: filePath, insert_line: 0, new_str: 'First line' }, + context + ) + expect(result).toContain('has been edited') + + const fileContent = await fs.readFile(filePath, 'utf-8') + expect(fileContent).toBe('First line') + }) + + describe('error cases', () => { + it('throws when insert_line is negative', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2') + await expect( + fileEditor.invoke({ command: 'insert', path: filePath, insert_line: -1, new_str: 'NEW' }, context) + ).rejects.toThrow('insert_line') + }) + + it('throws when insert_line is beyond file length', async () => { + const filePath = await createTestFile('test.txt', 'Line 1\nLine 2') + await expect( + fileEditor.invoke({ command: 'insert', path: filePath, insert_line: 10, new_str: 'NEW' }, context) + ).rejects.toThrow('insert_line') + }) + + it('throws when file not found', async () => { + const nonExistentPath = path.join(testDir, 'nonexistent.txt') + await expect( + fileEditor.invoke({ command: 'insert', path: nonExistentPath, insert_line: 0, new_str: 'NEW' }, context) + ).rejects.toThrow('does not exist') + }) + + it('throws when path is directory', async () => { + const dirPath = await createTestDirectory('testdir', {}) + await expect( + fileEditor.invoke({ command: 'insert', path: dirPath, insert_line: 0, new_str: 'NEW' }, context) + ).rejects.toThrow('directory') + }) + }) + }) + + describe('path validation and security', () => { + it('rejects relative paths', async () => { + await expect(fileEditor.invoke({ command: 'view', path: 'relative/path.txt' }, context)).rejects.toThrow( + 'not an absolute path' + ) + }) + }) + + describe('file size limits', () => { + it('throws when file exceeds default size limit', async () => { + // Create a file larger than 1MB + const largeContent = 'x'.repeat(1048577) // 1MB + 1 byte + const filePath = await createTestFile('large.txt', largeContent) + + await expect(fileEditor.invoke({ command: 'view', path: filePath }, context)).rejects.toThrow('exceeds') + }) + }) + + describe('edge cases', () => { + it('handles files with special characters in content', async () => { + const content = 'Special chars: @#$%^&*()_+-={}[]|:;"<>,.?/~`' + const filePath = await createTestFile('special.txt', content) + const result = await fileEditor.invoke({ command: 'view', path: filePath }, context) + expect(result).toContain('Special chars:') + }) + + it('handles files with unicode characters', async () => { + const content = '你好世界\n🚀 Emoji test\nΣ Greek letters' + const filePath = await createTestFile('unicode.txt', content) + const result = await fileEditor.invoke({ command: 'view', path: filePath }, context) + expect(result).toContain('你好世界') + expect(result).toContain('🚀') + }) + + it('handles files with tabs (expands tabs)', async () => { + const content = 'Line 1\tTab\tSeparated' + const filePath = await createTestFile('tabs.txt', content) + const result = await fileEditor.invoke({ command: 'view', path: filePath }, context) + // Tabs should be expanded to spaces + expect(result).not.toContain('\t') + }) + }) +}) diff --git a/strands-ts/src/vended-tools/file-editor/file-editor.ts b/strands-ts/src/vended-tools/file-editor/file-editor.ts new file mode 100644 index 0000000000..e8fea542cc --- /dev/null +++ b/strands-ts/src/vended-tools/file-editor/file-editor.ts @@ -0,0 +1,454 @@ +import { tool } from '../../tools/tool-factory.js' +import { z } from 'zod' +import type { IFileReader } from './types.js' +import { promises as fs } from 'fs' +import * as path from 'path' + +const SNIPPET_LINES = 4 +const DEFAULT_MAX_FILE_SIZE = 1048576 // 1MB +const MAX_DIRECTORY_DEPTH = 2 + +/** + * Zod schema for file editor input validation. + */ +const fileEditorInputSchema = z.object({ + command: z + .enum(['view', 'create', 'str_replace', 'insert']) + .describe('The operation to perform: `view`, `create`, `str_replace`, `insert`.'), + path: z.string().describe('Absolute path to the file or directory.'), + file_text: z.string().optional().describe('Content for new file (required for create command).'), + view_range: z + .tuple([z.number(), z.number()]) + .optional() + .describe('Line range to view [start, end]. 1-indexed. End can be -1 for end of file.'), + old_str: z.string().optional().describe('Exact string to find and replace (required for str_replace command).'), + new_str: z.string().optional().describe('Replacement string (for str_replace and insert commands).'), + insert_line: z + .number() + .optional() + .describe('Line number where text should be inserted (0-indexed, required for insert command).'), +}) + +/** + * Text file reader implementation. + * Reads files as UTF-8 encoded text. + */ +class TextFileReader implements IFileReader { + async read(filePath: string): Promise { + return await fs.readFile(filePath, 'utf-8') + } +} + +/** + * File editor tool for viewing, creating, and editing files programmatically. + * + * Provides commands for viewing files/directories, creating files, string replacement, + * and line insertion. + * + * @example + * ```typescript + * import { fileEditor } from '@strands-agents/sdk/vended-tools/file-editor' + * import { Agent } from '@strands-agents/sdk' + * + * const agent = new Agent({ + * model: new BedrockModel({ region: 'us-east-1' }), + * tools: [fileEditor], + * }) + * + * await agent.invoke('View the file /tmp/test.txt') + * await agent.invoke('Create a file /tmp/notes.txt with content "Hello World"') + * await agent.invoke('Replace "Hello" with "Hi" in /tmp/notes.txt') + * ``` + */ +export const fileEditor = tool({ + name: 'fileEditor', + description: + 'Filesystem editor tool for viewing, creating, and editing files. Supports view (with line ranges), create, str_replace, and insert operations. Files must use absolute paths.', + inputSchema: fileEditorInputSchema, + callback: async (input, context) => { + if (!context) { + throw new Error('Tool context is required for file editor operations') + } + + const fileReader = new TextFileReader() + + let result: string + + switch (input.command) { + case 'view': + result = await handleView(input.path, input.view_range, fileReader) + break + + case 'create': + result = await handleCreate(input.path, input.file_text!) + break + + case 'str_replace': + result = await handleStrReplace(input.path, input.old_str!, input.new_str, fileReader) + break + + case 'insert': + result = await handleInsert(input.path, input.insert_line!, input.new_str!, fileReader) + break + + default: + throw new Error(`Unknown command: ${input.command}`) + } + + return result + }, +}) + +/** + * Validates that a path is absolute and doesn't contain directory traversal. + */ +function validatePath(command: string, filePath: string): void { + // Check if it's an absolute path + if (!path.isAbsolute(filePath)) { + const suggestedPath = path.resolve(filePath) + throw new Error( + `The path ${filePath} is not an absolute path, it should start with \`/\`. Maybe you meant ${suggestedPath}?` + ) + } + + // Check for directory traversal - reject paths containing '..' segments + const normalized = path.normalize(filePath) + if (normalized.includes('..')) { + throw new Error(`Invalid path: path traversal is not allowed`) + } +} + +/** + * Checks if a file exists. + */ +async function fileExists(filePath: string): Promise { + try { + await fs.access(filePath) + return true + } catch { + return false + } +} + +/** + * Checks if a path is a directory. + */ +async function isDirectory(filePath: string): Promise { + try { + const stats = await fs.stat(filePath) + return stats.isDirectory() + } catch { + return false + } +} + +/** + * Checks file size against limit. + */ +async function checkFileSize(filePath: string, maxSize: number = DEFAULT_MAX_FILE_SIZE): Promise { + const stats = await fs.stat(filePath).catch((err) => { + throw new Error(`Failed to check file size: ${err}`) + }) + + if (stats.size > maxSize) { + throw new Error(`File size (${stats.size} bytes) exceeds maximum allowed size (${maxSize} bytes)`) + } +} + +/** + * Formats file content with line numbers (cat -n style). + */ +function makeOutput(fileContent: string, fileDescriptor: string, initLine: number = 1): string { + // Expand tabs to spaces in content + const expandedContent = fileContent.replace(/\t/g, ' ') + + const numberedLines = expandedContent.split('\n').map((line, index) => { + const lineNum = index + initLine + // Use two spaces instead of tab to avoid any tabs in output + return `${lineNum.toString().padStart(6)} ${line}` + }) + + return `Here's the result of running \`cat -n\` on ${fileDescriptor}:\n${numberedLines.join('\n')}\n` +} + +/** + * Lists directory contents up to 2 levels deep, excluding hidden files. + */ +async function listDirectory(dirPath: string): Promise { + const items: string[] = [] + + async function walk(currentPath: string, depth: number): Promise { + try { + const entries = await fs.readdir(currentPath, { withFileTypes: true }) + + for (const entry of entries) { + // Skip hidden files/directories + if (entry.name.startsWith('.')) continue + + const fullPath = path.join(currentPath, entry.name) + const relativePath = path.relative(dirPath, fullPath) + items.push(relativePath || entry.name) + + // Continue walking if we haven't reached max depth yet + if (entry.isDirectory() && depth < MAX_DIRECTORY_DEPTH) { + await walk(fullPath, depth + 1) + } + } + } catch { + // Ignore permission errors and continue + } + } + + await walk(dirPath, 0) + + const result = items.sort().join('\n') + return `Here's the files and directories up to 2 levels deep in ${dirPath}, excluding hidden items:\n${result}\n` +} + +/** + * Handles the view command. + */ +async function handleView( + filePath: string, + viewRange: [number, number] | undefined, + fileReader: IFileReader +): Promise { + validatePath('view', filePath) + + const exists = await fileExists(filePath) + if (!exists) { + throw new Error(`The path ${filePath} does not exist. Please provide a valid path.`) + } + + const isDir = await isDirectory(filePath) + + if (isDir) { + if (viewRange) { + throw new Error('The `view_range` parameter is not allowed when `path` points to a directory.') + } + return await listDirectory(filePath) + } + + // Check file size before reading + await checkFileSize(filePath) + + // Read file content - only if not a directory + const fileContent = await fileReader.read(filePath) + + let initLine = 1 + let contentToShow = fileContent + + if (viewRange) { + const lines = fileContent.split('\n') + const nLines = lines.length + let [start, end] = viewRange + + // Validate range + if (start < 1 || start > nLines) { + throw new Error( + `Invalid \`view_range\`: [${start}, ${end}]. Its first element \`${start}\` should be within the range of lines of the file: [1, ${nLines}]` + ) + } + + if (end !== -1 && end > nLines) { + throw new Error( + `Invalid \`view_range\`: [${start}, ${end}]. Its second element \`${end}\` should be smaller than the number of lines in the file: \`${nLines}\`` + ) + } + + if (end !== -1 && end < start) { + throw new Error( + `Invalid \`view_range\`: [${start}, ${end}]. Its second element \`${end}\` should be larger or equal than its first \`${start}\`` + ) + } + + initLine = start + if (end === -1) { + contentToShow = lines.slice(start - 1).join('\n') + } else { + contentToShow = lines.slice(start - 1, end).join('\n') + } + } + + return makeOutput(contentToShow, filePath, initLine) +} + +/** + * Handles the create command. + */ +async function handleCreate(filePath: string, fileText: string): Promise { + if (fileText === undefined) { + throw new Error('Parameter `file_text` is required for command: create') + } + + validatePath('create', filePath) + + const exists = await fileExists(filePath) + if (exists) { + throw new Error(`File already exists at: ${filePath}. Cannot overwrite files using command \`create\`.`) + } + + // Create parent directories if needed + const dir = path.dirname(filePath) + await fs.mkdir(dir, { recursive: true }) + + // Write file + await fs.writeFile(filePath, fileText, 'utf-8') + + return `File created successfully at: ${filePath}` +} + +/** + * Handles the str_replace command. + */ +async function handleStrReplace( + filePath: string, + oldStr: string, + newStr: string | undefined, + fileReader: IFileReader +): Promise { + if (oldStr === undefined) { + throw new Error('Parameter `old_str` is required for command: str_replace') + } + + validatePath('str_replace', filePath) + + const exists = await fileExists(filePath) + if (!exists) { + throw new Error(`The path ${filePath} does not exist. Please provide a valid path.`) + } + + const isDir = await isDirectory(filePath) + if (isDir) { + throw new Error(`The path ${filePath} is a directory and only the \`view\` command can be used on directories`) + } + + await checkFileSize(filePath) + + // Read file content + let fileContent = await fileReader.read(filePath) + + // Expand tabs in content and search string + fileContent = fileContent.replace(/\t/g, ' ') + const expandedOldStr = oldStr.replace(/\t/g, ' ') + const expandedNewStr = newStr ? newStr.replace(/\t/g, ' ') : '' + + // Check if old_str is unique + const occurrences = (fileContent.match(new RegExp(escapeRegExp(expandedOldStr), 'g')) || []).length + + if (occurrences === 0) { + throw new Error(`No replacement was performed, old_str \`${oldStr}\` did not appear verbatim in ${filePath}.`) + } + + if (occurrences > 1) { + const lines = fileContent.split('\n') + const lineNumbers = lines + .map((line, index) => (line.includes(expandedOldStr) ? index + 1 : -1)) + .filter((num) => num !== -1) + throw new Error( + `No replacement was performed. Multiple occurrences of old_str \`${oldStr}\` in lines ${JSON.stringify(lineNumbers)}. Please ensure it is unique` + ) + } + + // Perform replacement + const newFileContent = fileContent.replace(expandedOldStr, () => expandedNewStr) + + // Write back to file + await fs.writeFile(filePath, newFileContent, 'utf-8') + + // Create snippet + const replacementLine = fileContent.substring(0, fileContent.indexOf(expandedOldStr)).split('\n').length - 1 + const insertedLines = expandedNewStr.split('\n').length + const originalLines = expandedOldStr.split('\n').length + const lineDifference = insertedLines - originalLines + + const lines = newFileContent.split('\n') + const startLine = Math.max(0, replacementLine - SNIPPET_LINES) + const endLine = Math.min(lines.length, replacementLine + SNIPPET_LINES + lineDifference + 1) + const snippetLines = lines.slice(startLine, endLine) + const snippet = snippetLines.join('\n') + + const successMsg = `The file ${filePath} has been edited. ${makeOutput(snippet, `a snippet of ${filePath}`, startLine + 1)}Review the changes and make sure they are as expected. Edit the file again if necessary.` + + return successMsg +} + +/** + * Handles the insert command. + */ +async function handleInsert( + filePath: string, + insertLine: number, + newStr: string, + fileReader: IFileReader +): Promise { + if (insertLine === undefined || newStr === undefined) { + throw new Error('Parameters `insert_line` and `new_str` are required for command: insert') + } + + validatePath('insert', filePath) + + const exists = await fileExists(filePath) + if (!exists) { + throw new Error(`The path ${filePath} does not exist. Please provide a valid path.`) + } + + const isDir = await isDirectory(filePath) + if (isDir) { + throw new Error(`The path ${filePath} is a directory and only the \`view\` command can be used on directories`) + } + + await checkFileSize(filePath) + + // Read file content + let fileText = await fileReader.read(filePath) + + // Expand tabs + fileText = fileText.replace(/\t/g, ' ') + const expandedNewStr = newStr.replace(/\t/g, ' ') + + const fileTextLines = fileText.split('\n') + const nLines = fileTextLines.length + + // Validate insert_line + if (insertLine < 0 || insertLine > nLines) { + throw new Error( + `Invalid \`insert_line\` parameter: ${insertLine}. It should be within the range of lines of the file: [0, ${nLines}]` + ) + } + + // Perform insertion + const newStrLines = expandedNewStr.split('\n') + + // Handle empty file case + let newFileTextLines: string[] + if (fileText === '') { + newFileTextLines = newStrLines + } else { + newFileTextLines = [...fileTextLines.slice(0, insertLine), ...newStrLines, ...fileTextLines.slice(insertLine)] + } + + const newFileText = newFileTextLines.join('\n') + + // Write back to file + await fs.writeFile(filePath, newFileText, 'utf-8') + + // Create snippet - show lines around the insertion point + // Show 4 lines before the insertion line and 4 lines after + const snippetStartLine = Math.max(0, insertLine - SNIPPET_LINES) + const snippetEndLine = Math.min(newFileTextLines.length, insertLine + newStrLines.length + SNIPPET_LINES) + const snippetLines = newFileTextLines.slice(snippetStartLine, snippetEndLine) + const snippet = snippetLines.join('\n') + const startLine = snippetStartLine + 1 + + const successMsg = `The file ${filePath} has been edited. ${makeOutput(snippet, 'a snippet of the edited file', startLine)}Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary.` + + return successMsg +} + +/** + * Escapes special regex characters in a string. + */ +function escapeRegExp(string: string): string { + return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&') +} diff --git a/strands-ts/src/vended-tools/file-editor/index.ts b/strands-ts/src/vended-tools/file-editor/index.ts new file mode 100644 index 0000000000..bfcf0b0a11 --- /dev/null +++ b/strands-ts/src/vended-tools/file-editor/index.ts @@ -0,0 +1,6 @@ +/** + * File editor tool for programmatic filesystem interaction. + */ + +export { fileEditor } from './file-editor.js' +export type { FileEditorInput, FileEditorOptions, IFileReader } from './types.js' diff --git a/strands-ts/src/vended-tools/file-editor/types.ts b/strands-ts/src/vended-tools/file-editor/types.ts new file mode 100644 index 0000000000..2bd89a92c6 --- /dev/null +++ b/strands-ts/src/vended-tools/file-editor/types.ts @@ -0,0 +1,66 @@ +/** + * Configuration options for the file editor tool. + */ +export interface FileEditorOptions { + /** + * Maximum file size in bytes that can be read (default: 1048576 / 1MB). + */ + maxFileSize?: number +} + +/** + * Input parameters for view operation. + */ +export interface ViewInput { + command: 'view' + path: string + view_range?: [number, number] +} + +/** + * Input parameters for create operation. + */ +export interface CreateInput { + command: 'create' + path: string + file_text: string +} + +/** + * Input parameters for str_replace operation. + */ +export interface StrReplaceInput { + command: 'str_replace' + path: string + old_str: string + new_str?: string +} + +/** + * Input parameters for insert operation. + */ +export interface InsertInput { + command: 'insert' + path: string + insert_line: number + new_str: string +} + +/** + * Union type of all valid file editor inputs. + */ +export type FileEditorInput = ViewInput | CreateInput | StrReplaceInput | InsertInput + +/** + * Interface for pluggable file readers. + * Allows extending the file editor to support different file types. + */ +export interface IFileReader { + /** + * Reads the file content and returns it as a string. + * + * @param path - Absolute path to the file + * @returns File content as a string + */ + read(path: string): Promise +} diff --git a/strands-ts/src/vended-tools/http-request/README.md b/strands-ts/src/vended-tools/http-request/README.md new file mode 100644 index 0000000000..f2bf6c91a2 --- /dev/null +++ b/strands-ts/src/vended-tools/http-request/README.md @@ -0,0 +1,83 @@ +# HTTP Request Tool + +A cross-platform HTTP request tool for making HTTP requests to external APIs from Strands agents. + +## Features + +- **All HTTP Methods**: Supports GET, POST, PUT, DELETE, PATCH, HEAD, and OPTIONS +- **Cross-Platform**: Uses native `fetch` API - works in Node.js 20+ and all modern browsers +- **Timeout Support**: Configurable request timeout with default of 30 seconds +- **Type-Safe**: Full TypeScript support with Zod schema validation +- **Comprehensive Error Handling**: Network errors, timeouts, and HTTP errors are properly handled + +## Installation + +```bash +npm install @strands-agents/sdk +``` + +## Usage + +### With an Agent + +```typescript +import { Agent } from '@strands-agents/sdk' +import { httpRequest } from '@strands-agents/sdk/vended-tools/http-request' + +const agent = new Agent({ + tools: [httpRequest], +}) + +// Agent will use the tool based on your prompts +await agent.invoke('Get data from https://api.example.com/data') +``` + +### Direct Invocation + +```typescript +import { httpRequest } from '@strands-agents/sdk/vended-tools/http-request' + +// Simple GET request +const response = await httpRequest.invoke({ + method: 'GET', + url: 'https://api.example.com/data', +}) + +console.log(response.status) // 200 +console.log(response.body) // Response body as text +``` + +## API + +### Input + +The tool accepts an object with the following properties: + +| Property | Type | Required | Default | Description | +| --------- | ------------------------------------------------------------------------ | -------- | ------- | ------------------------------------ | +| `method` | `'GET' \| 'POST' \| 'PUT' \| 'DELETE' \| 'PATCH' \| 'HEAD' \| 'OPTIONS'` | Yes | - | HTTP method to use | +| `url` | `string` | Yes | - | URL to send the request to | +| `headers` | `Record` | No | - | Optional HTTP headers | +| `body` | `string` | No | - | Optional request body (for POST/PUT) | +| `timeout` | `number` | No | 30 | Timeout in seconds | + +### Output + +Returns an object with the following properties: + +| Property | Type | Description | +| ------------ | ------------------------ | -------------------------------- | +| `status` | `number` | HTTP status code | +| `statusText` | `string` | HTTP status text | +| `headers` | `Record` | Response headers as plain object | +| `body` | `string` | Response body as text | + +### Error Handling + +The tool throws standard JavaScript Error objects in the following cases: + +- **Timeout Error**: Request exceeds the specified timeout (error message includes "Request timed out") +- **HTTP Error**: HTTP response with non-2xx status code (error message includes HTTP status code and status text) +- **Network Errors**: Connection failures, DNS resolution failures, etc. + +When used within an agent, these errors are automatically converted to tool execution errors. diff --git a/strands-ts/src/vended-tools/http-request/__tests__/http-request.test.ts b/strands-ts/src/vended-tools/http-request/__tests__/http-request.test.ts new file mode 100644 index 0000000000..2b94b18433 --- /dev/null +++ b/strands-ts/src/vended-tools/http-request/__tests__/http-request.test.ts @@ -0,0 +1,235 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { httpRequest } from '../http-request.js' + +describe('httpRequest tool', () => { + const originalFetch = globalThis.fetch + + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + globalThis.fetch = originalFetch + }) + + describe.each([ + { method: 'GET' as const, status: 200, statusText: 'OK' }, + { method: 'POST' as const, status: 201, statusText: 'Created' }, + { method: 'PUT' as const, status: 200, statusText: 'OK' }, + { method: 'DELETE' as const, status: 204, statusText: 'No Content' }, + { method: 'PATCH' as const, status: 200, statusText: 'OK' }, + { method: 'HEAD' as const, status: 200, statusText: 'OK' }, + { method: 'OPTIONS' as const, status: 200, statusText: 'OK' }, + ])('$method request', ({ method, status, statusText }) => { + it('returns successful response', async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + status, + statusText, + headers: new Map([['content-type', 'application/json']]), + text: async () => '{"success":true}', + }) + + const result = await httpRequest.invoke({ + method, + url: 'https://api.example.com/resource', + }) + + expect(result.status).toBe(status) + expect(result.statusText).toBe(statusText) + expect(result.headers['content-type']).toBe('application/json') + expect(result.body).toBe('{"success":true}') + }) + }) + + describe('request configuration', () => { + it('sends request with custom headers and body', async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + headers: new Map([]), + text: async () => '{"id":123}', + }) + + await httpRequest.invoke({ + method: 'POST', + url: 'https://api.example.com/users', + body: '{"name":"test"}', + headers: { 'Content-Type': 'application/json' }, + }) + + expect(globalThis.fetch).toHaveBeenCalledWith( + 'https://api.example.com/users', + expect.objectContaining({ + method: 'POST', + body: '{"name":"test"}', + headers: { 'Content-Type': 'application/json' }, + }) + ) + }) + }) + + describe('response handling', () => { + it('handles empty response body', async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + status: 204, + statusText: 'No Content', + headers: new Map([]), + text: async () => '', + }) + + const result = await httpRequest.invoke({ + method: 'DELETE', + url: 'https://api.example.com/resource', + }) + + expect(result.body).toBe('') + }) + + it('handles string response body', async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + headers: new Map([]), + text: async () => 'Plain text response', + }) + + const result = await httpRequest.invoke({ + method: 'GET', + url: 'https://api.example.com/text', + }) + + expect(result.body).toBe('Plain text response') + }) + + it('converts response headers to plain object', async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + headers: new Map([ + ['content-type', 'application/json'], + ['x-custom-header', 'value'], + ]), + text: async () => '{}', + }) + + const result = await httpRequest.invoke({ + method: 'GET', + url: 'https://api.example.com', + }) + + expect(result.headers).toEqual({ + 'content-type': 'application/json', + 'x-custom-header': 'value', + }) + }) + }) + + describe('HTTP status codes', () => { + it('succeeds for 2xx status codes', async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: true, + status: 201, + statusText: 'Created', + headers: new Map([]), + text: async () => 'created', + }) + + const result = await httpRequest.invoke({ + method: 'POST', + url: 'https://api.example.com', + }) + + expect(result.status).toBe(201) + }) + + it('throws error for 3xx redirect status codes', async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 301, + statusText: 'Moved Permanently', + headers: new Map([]), + text: async () => '', + }) + + await expect( + httpRequest.invoke({ + method: 'GET', + url: 'https://api.example.com/moved', + }) + ).rejects.toThrow('HTTP 301') + }) + + it('throws error for 4xx client error codes', async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 404, + statusText: 'Not Found', + headers: new Map([]), + text: async () => 'Not found', + }) + + await expect( + httpRequest.invoke({ + method: 'GET', + url: 'https://api.example.com/notfound', + }) + ).rejects.toThrow('HTTP 404 Not Found') + }) + + it('throws error for 5xx server error codes', async () => { + globalThis.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + headers: new Map([]), + text: async () => 'Server error', + }) + + await expect( + httpRequest.invoke({ + method: 'GET', + url: 'https://api.example.com/error', + }) + ).rejects.toThrow('HTTP 500') + }) + }) + + describe('error handling', () => { + it('throws timeout error when request exceeds timeout', async () => { + globalThis.fetch = vi.fn().mockImplementation( + async (_url, _options) => + new Promise((_resolve, reject) => { + globalThis.setTimeout(() => { + const error = new Error('The operation was aborted') + error.name = 'AbortError' + reject(error) + }, 100) + }) + ) + + await expect( + httpRequest.invoke({ + method: 'GET', + url: 'https://slow-api.example.com', + timeout: 0.1, + }) + ).rejects.toThrow('Request timed out') + }) + + it('throws error for network failures', async () => { + globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error: Failed to fetch')) + + await expect( + httpRequest.invoke({ + method: 'GET', + url: 'https://invalid-domain.com', + }) + ).rejects.toThrow('Network error: Failed to fetch') + }) + }) +}) diff --git a/strands-ts/src/vended-tools/http-request/http-request.ts b/strands-ts/src/vended-tools/http-request/http-request.ts new file mode 100644 index 0000000000..c0d01acd76 --- /dev/null +++ b/strands-ts/src/vended-tools/http-request/http-request.ts @@ -0,0 +1,87 @@ +import { tool } from '../../tools/tool-factory.js' +import { z } from 'zod' + +/** + * Zod schema for HTTP request input validation. + */ +const httpRequestInputSchema = z.object({ + method: z + .enum(['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS']) + .describe('HTTP method to use for the request'), + url: z.string().url().describe('URL to send the request to'), + headers: z.record(z.string(), z.string()).optional().describe('Optional HTTP headers as key-value pairs'), + body: z.string().optional().describe('Optional request body as a string'), + timeout: z.number().positive().optional().describe('Optional timeout in seconds (default: 30)'), +}) + +/** + * HTTP request tool for making HTTP requests to external APIs. + * + * Supports all standard HTTP methods (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS) + * and provides comprehensive request configuration including headers, body, and timeout. + * + * @example + * ```typescript + * // With agent + * const agent = new Agent({ tools: [httpRequest] }) + * await agent.invoke('Make a GET request to https://api.example.com/data') + * + * // Direct usage + * const response = await httpRequest.invoke({ + * method: 'POST', + * url: 'https://api.example.com/users', + * headers: { 'Content-Type': 'application/json' }, + * body: '{"name":"test"}', + * timeout: 10 + * }) + * ``` + */ +export const httpRequest = tool({ + name: 'http_request', + description: + 'Makes HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH, HEAD, and OPTIONS methods. Returns response with status, headers, and body.', + inputSchema: httpRequestInputSchema, + callback: async (input, context) => { + const { method, url, headers, body, timeout = 30 } = input + + // Abort on timeout or agent cancellation, whichever comes first + const timeoutSignal = AbortSignal.timeout(timeout * 1000) + const signal = context ? AbortSignal.any([timeoutSignal, context.agent.cancelSignal]) : timeoutSignal + + try { + const fetchOptions: RequestInit = { method, signal } + + if (headers !== undefined) { + fetchOptions.headers = headers + } + if (body !== undefined) { + fetchOptions.body = body + } + + const response = await globalThis.fetch(url, fetchOptions) + const responseBody = await response.text() + + const responseHeaders: Record = {} + response.headers.forEach((value, key) => { + responseHeaders[key] = value + }) + + if (!response.ok) { + throw new Error(`HTTP ${response.status} ${response.statusText}: ${method} ${url}`) + } + + return { + status: response.status, + statusText: response.statusText, + headers: responseHeaders, + body: responseBody, + } + } catch (error) { + if (error instanceof Error && error.name === 'AbortError') { + const reason = timeoutSignal.aborted ? `timed out after ${timeout} seconds` : 'cancelled' + throw new Error(`Request ${reason}: ${method} ${url}`) + } + throw error + } + }, +}) diff --git a/strands-ts/src/vended-tools/http-request/index.ts b/strands-ts/src/vended-tools/http-request/index.ts new file mode 100644 index 0000000000..42f375bc5a --- /dev/null +++ b/strands-ts/src/vended-tools/http-request/index.ts @@ -0,0 +1,6 @@ +/** + * HTTP request tool for making HTTP requests to external APIs. + */ + +export { httpRequest } from './http-request.js' +export type { HttpRequestInput, HttpRequestOutput } from './types.js' diff --git a/strands-ts/src/vended-tools/http-request/types.ts b/strands-ts/src/vended-tools/http-request/types.ts new file mode 100644 index 0000000000..c95dee9e61 --- /dev/null +++ b/strands-ts/src/vended-tools/http-request/types.ts @@ -0,0 +1,54 @@ +/** + * Input parameters for HTTP request. + */ +export interface HttpRequestInput { + /** + * HTTP method to use for the request. + */ + method: 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | 'HEAD' | 'OPTIONS' + + /** + * URL to send the request to. + */ + url: string + + /** + * Optional HTTP headers as key-value pairs. + */ + headers?: Record + + /** + * Optional request body as a string. + */ + body?: string + + /** + * Optional timeout in seconds (default: 30). + */ + timeout?: number +} + +/** + * Output from HTTP request containing response details. + */ +export interface HttpRequestOutput { + /** + * HTTP status code. + */ + status: number + + /** + * HTTP status text. + */ + statusText: string + + /** + * Response headers as key-value pairs. + */ + headers: Record + + /** + * Response body as text. + */ + body: string +} diff --git a/strands-ts/src/vended-tools/index.ts b/strands-ts/src/vended-tools/index.ts new file mode 100644 index 0000000000..564be7da62 --- /dev/null +++ b/strands-ts/src/vended-tools/index.ts @@ -0,0 +1,17 @@ +/** + * Barrel export for all vended tools. + * + * Provides a single import path for consumers who want all built-in tools: + * ```typescript + * import { bash, fileEditor, httpRequest, notebook } from '@strands-agents/sdk/vended-tools' + * ``` + * + * Note: This module requires a Node.js environment because the `bash` tool + * imports `child_process`. For browser-compatible usage, import individual + * tools via their subpath exports (e.g., `@strands-agents/sdk/vended-tools/notebook`). + */ + +export * from './bash/index.js' +export * from './file-editor/index.js' +export * from './http-request/index.js' +export * from './notebook/index.js' diff --git a/strands-ts/src/vended-tools/notebook/README.md b/strands-ts/src/vended-tools/notebook/README.md new file mode 100644 index 0000000000..198143e01c --- /dev/null +++ b/strands-ts/src/vended-tools/notebook/README.md @@ -0,0 +1,182 @@ +# Notebook Tool + +A tool for managing persistent text notebooks within agent sessions. The notebook tool allows agents to create, read, write, list, and clear notebooks with automatic state persistence. + +## Installation + +```typescript +import { Agent, BedrockModel } from '@strands-agents/sdk' +import { notebook } from '@strands-agents/sdk/vended-tools/notebook' +``` + +## Quick Start + +### Creating an Agent with the Notebook Tool + +```typescript +import { Agent, BedrockModel } from '@strands-agents/sdk' +import { notebook } from '@strands-agents/sdk/vended-tools/notebook' + +// Create an agent with the notebook tool +const agent = new Agent({ + model: new BedrockModel({ + region: 'us-east-1', + }), + tools: [notebook], +}) + +// Use natural language to interact with notebooks +await agent.invoke('Create a notebook called "ideas" with the title "# Project Ideas"') +await agent.invoke('Add "- Build a web scraper" to the ideas notebook') +await agent.invoke('Add "- Create a CLI tool" to the ideas notebook') +await agent.invoke('Read the ideas notebook') +``` + +### State Persistence + +The notebook tool automatically persists state within an agent session: + +```typescript +// Notebooks persist across multiple invocations +await agent.invoke('Create a notebook called "todo" with "# Tasks"') +await agent.invoke('Add "- [ ] Review code" to the todo notebook') +await agent.invoke('Add "- [ ] Write tests" to the todo notebook') + +// State is accessible via the agent +console.log(agent.appState.get('notebooks')) +// Output: { todo: '# Tasks\n- [ ] Review code\n- [ ] Write tests' } +``` + +### Saving and Restoring State + +Save notebook state across application restarts: + +```typescript +// Save the current state +const savedState = agent.appState.getAll() + +// Later, create a new agent with the saved state +const restoredAgent = new Agent({ + model: new BedrockModel({ + region: 'us-east-1', + }), + tools: [notebook], + appState: savedState, // Restore previous notebooks +}) + +// All notebooks are immediately available +await restoredAgent.invoke('List all notebooks') +await restoredAgent.invoke('Read the todo notebook') +``` + +## Notebook Operations + +The agent can perform these operations through natural language: + +- **Create**: "Create a notebook called 'notes' with '# My Notes'" +- **List**: "List all notebooks" +- **Read**: "Read the notes notebook" or "Read lines 5-10 from notes" +- **Write**: + - Replace: "Replace 'old text' with 'new text' in notes" + - Insert: "Add 'new line' to the notes notebook" +- **Clear**: "Clear the notes notebook" + +## Example: Building a Task Manager + +```typescript +const agent = new Agent({ + model: new BedrockModel({ + region: 'us-east-1', + }), + tools: [notebook], +}) + +// Create a task list +await agent.invoke('Create a notebook called "tasks" with "# Daily Tasks\n\n## Todo\n"') + +// Add tasks +await agent.invoke('Add "- [ ] Morning standup" to the tasks notebook') +await agent.invoke('Add "- [ ] Code review" to the tasks notebook') +await agent.invoke('Add "- [ ] Update documentation" to the tasks notebook') + +// Complete a task +await agent.invoke('Replace "- [ ] Morning standup" with "- [x] Morning standup" in tasks') + +// Check progress +const result = await agent.invoke('Read the tasks notebook') + +// Save state for tomorrow +const taskState = agent.appState.getAll() +// Store taskState in your database/file system +``` + +## Direct Tool Usage + +You can also use the notebook tool directly without an agent: + +```typescript +import { notebook } from '@strands-agents/sdk/vended-tools/notebook' +import { StateStore } from '@strands-agents/sdk' + +const state = new StateStore({ notebooks: {} }) +const agent = { appState: state } +const context = { + agent, + toolUse: { name: 'notebook', toolUseId: 'test', input: {} }, +} + +// Create and write to a notebook +await notebook.invoke( + { + mode: 'create', + name: 'direct', + newStr: 'Direct notebook content', + }, + context +) + +// Read the notebook +const content = await notebook.invoke( + { + mode: 'read', + name: 'direct', + }, + context +) +``` + +## Key Features + +- **Multiple Notebooks**: Manage multiple named notebooks simultaneously +- **Automatic Persistence**: State persists within agent sessions automatically +- **Natural Language**: Interact with notebooks using natural language through the agent +- **State Management**: Save and restore notebook state across application restarts +- **Type Safety**: Full TypeScript support with runtime validation +- **Universal**: Works in both browser and server environments + +## API Reference + +### Input Schema + +```typescript +type NotebookInput = { + mode: 'create' | 'list' | 'read' | 'write' | 'clear' + name?: string // Notebook name (defaults to 'default') + newStr?: string // Content for create/write operations + oldStr?: string // Text to replace (write mode) + insertLine?: string | number // Line to insert after (write mode) + readRange?: [number, number] // Line range for read (1-indexed) +} +``` + +### State Structure + +```typescript +interface NotebookState { + notebooks: Record // name -> content mapping +} +``` + +## License + +Same license as the Strands SDK. diff --git a/strands-ts/src/vended-tools/notebook/__tests__/notebook.test.ts b/strands-ts/src/vended-tools/notebook/__tests__/notebook.test.ts new file mode 100644 index 0000000000..e00147f273 --- /dev/null +++ b/strands-ts/src/vended-tools/notebook/__tests__/notebook.test.ts @@ -0,0 +1,506 @@ +import { describe, it, expect } from 'vitest' +import { notebook } from '../notebook.js' +import type { NotebookState } from '../types.js' +import type { ToolContext } from '../../../index.js' +import { StateStore } from '../../../state-store.js' +import { createMockAgent } from '../../../__fixtures__/agent-helpers.js' + +describe('notebook tool', () => { + // Helper to create fresh state and context for each test + const createFreshContext = (): { state: StateStore; context: ToolContext } => { + const agent = createMockAgent({ appState: { notebooks: {} } }) + const context: ToolContext = { + toolUse: { + name: 'notebook', + toolUseId: 'test-id', + input: {}, + }, + agent, + invocationState: {}, + interrupt: () => { + throw new Error('interrupt not available in mock context') + }, + } + return { state: agent.appState, context } + } + + describe('create oper ation', () => { + it('creates an empty notebook with default name', async () => { + const { state, context } = createFreshContext() + const result = await notebook.invoke({ mode: 'create' }, context) + expect(result).toBe("Created notebook 'default' (empty)") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('') + }) + + it('creates an empty notebook with custom name', async () => { + const { state, context } = createFreshContext() + const result = await notebook.invoke({ mode: 'create', name: 'notes' }, context) + expect(result).toBe("Created notebook 'notes' (empty)") + const notebooks = state.get('notebooks') + expect(notebooks!.notes).toBe('') + }) + + it('creates a notebook with initial content', async () => { + const { state, context } = createFreshContext() + const content = '# My Notes\n\nFirst entry' + const result = await notebook.invoke({ mode: 'create', name: 'notes', newStr: content }, context) + expect(result).toBe("Created notebook 'notes' with specified content") + const notebooks = state.get('notebooks') + expect(notebooks!.notes).toBe(content) + }) + + it('overwrites existing notebook on create', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { notes: 'Old content' }) + const result = await notebook.invoke({ mode: 'create', name: 'notes', newStr: 'New content' }, context) + expect(result).toBe("Created notebook 'notes' with specified content") + const notebooks = state.get('notebooks') + expect(notebooks!.notes).toBe('New content') + }) + }) + + describe('list operation', () => { + it('lists default notebook when initialized', async () => { + const { state, context } = createFreshContext() + // Initialize notebooks with default + state.set('notebooks', { default: '' }) + const result = await notebook.invoke({ mode: 'list' }, context) + expect(result).toContain('default: Empty') + }) + + it('lists multiple notebooks with line counts', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { + default: '', + notes: 'Line 1\nLine 2\nLine 3', + todo: 'Single line', + }) + + const result = await notebook.invoke({ mode: 'list' }, context) + expect(result).toContain('default: Empty') + expect(result).toContain('notes: 3 lines') + expect(result).toContain('todo: 1 lines') + }) + }) + + describe('read operation', () => { + it('reads entire notebook with default name', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5' }) + const result = await notebook.invoke({ mode: 'read' }, context) + expect(result).toBe('Line 1\nLine 2\nLine 3\nLine 4\nLine 5') + }) + + it('reads entire notebook with custom name', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { notes: 'Content here' }) + const result = await notebook.invoke({ mode: 'read', name: 'notes' }, context) + expect(result).toBe('Content here') + }) + + it('reads empty notebook', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { empty: '' }) + const result = await notebook.invoke({ mode: 'read', name: 'empty' }, context) + expect(result).toBe("Notebook 'empty' is empty") + }) + + it('throws error for non-existent notebook', async () => { + const { context } = createFreshContext() + await expect(notebook.invoke({ mode: 'read', name: 'missing' }, context)).rejects.toThrow( + "Notebook 'missing' not found" + ) + }) + + it('reads specific line range', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5' }) + const result = await notebook.invoke({ mode: 'read', readRange: [2, 4] }, context) + expect(result).toBe('2: Line 2\n3: Line 3\n4: Line 4') + }) + + it('reads line range with negative start index', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5' }) + const result = await notebook.invoke({ mode: 'read', readRange: [-3, 5] }, context) + expect(result).toBe('3: Line 3\n4: Line 4\n5: Line 5') + }) + + it('reads line range with negative end index', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5' }) + const result = await notebook.invoke({ mode: 'read', readRange: [1, -2] }, context) + expect(result).toBe('1: Line 1\n2: Line 2\n3: Line 3\n4: Line 4') + }) + + it('reads line range with both negative indices', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5' }) + const result = await notebook.invoke({ mode: 'read', readRange: [-2, -1] }, context) + expect(result).toBe('4: Line 4\n5: Line 5') + }) + + it('returns no valid lines for out of range', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5' }) + const result = await notebook.invoke({ mode: 'read', readRange: [10, 20] }, context) + expect(result).toBe('No valid lines found in range') + }) + }) + + describe('write operation - string replacement', () => { + it('replaces text in default notebook', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: '# Todo List\n\n[ ] Task 1\n[ ] Task 2\n[x] Task 3' }) + const result = await notebook.invoke( + { + mode: 'write', + oldStr: '[ ] Task 1', + newStr: '[x] Task 1', + }, + context + ) + expect(result).toBe("Replaced text in notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('# Todo List\n\n[x] Task 1\n[ ] Task 2\n[x] Task 3') + }) + + it('replaces text in custom notebook', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { notes: 'Original text' }) + const result = await notebook.invoke( + { + mode: 'write', + name: 'notes', + oldStr: 'Original', + newStr: 'Updated', + }, + context + ) + expect(result).toBe("Replaced text in notebook 'notes'") + const notebooks = state.get('notebooks') + expect(notebooks!.notes).toBe('Updated text') + }) + + it('replaces multiline text', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: '# Todo List\n\n[ ] Task 1\n[ ] Task 2\n[x] Task 3' }) + const result = await notebook.invoke( + { + mode: 'write', + oldStr: '[ ] Task 1\n[ ] Task 2', + newStr: '[x] Task 1\n[x] Task 2', + }, + context + ) + expect(result).toBe("Replaced text in notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('# Todo List\n\n[x] Task 1\n[x] Task 2\n[x] Task 3') + }) + + it('preserves dollar sign patterns in newStr literally', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'const value = getPrice()' }) + const result = await notebook.invoke( + { + mode: 'write', + oldStr: 'getPrice()', + newStr: '$& is not $1 or $$', + }, + context + ) + expect(result).toBe("Replaced text in notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('const value = $& is not $1 or $$') + }) + + it('throws error if old string not found', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: '# Todo List\n\n[ ] Task 1\n[ ] Task 2\n[x] Task 3' }) + await expect( + notebook.invoke( + { + mode: 'write', + oldStr: 'Nonexistent', + newStr: 'New', + }, + context + ) + ).rejects.toThrow("String 'Nonexistent' not found in notebook 'default'") + }) + + it('throws error for non-existent notebook', async () => { + const { context } = createFreshContext() + await expect( + notebook.invoke( + { + mode: 'write', + name: 'missing', + oldStr: 'Old', + newStr: 'New', + }, + context + ) + ).rejects.toThrow("Notebook 'missing' not found") + }) + }) + + describe('write operation - line insertion', () => { + it('inserts after line number', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3' }) + const result = await notebook.invoke( + { + mode: 'write', + insertLine: 2, + newStr: 'Inserted line', + }, + context + ) + expect(result).toBe("Inserted text at line 3 in notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('Line 1\nLine 2\nInserted line\nLine 3') + }) + + it('inserts at beginning (after line 0)', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3' }) + const result = await notebook.invoke( + { + mode: 'write', + insertLine: 0, + newStr: 'First line', + }, + context + ) + expect(result).toBe("Inserted text at line 1 in notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('First line\nLine 1\nLine 2\nLine 3') + }) + + it('appends to end with negative index', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3' }) + const result = await notebook.invoke( + { + mode: 'write', + insertLine: -1, + newStr: 'Last line', + }, + context + ) + expect(result).toBe("Inserted text at line 4 in notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('Line 1\nLine 2\nLine 3\nLast line') + }) + + it('inserts after negative line index', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3' }) + const result = await notebook.invoke( + { + mode: 'write', + insertLine: -2, + newStr: 'Before last', + }, + context + ) + expect(result).toBe("Inserted text at line 3 in notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('Line 1\nLine 2\nBefore last\nLine 3') + }) + + it('inserts after text search', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3' }) + const result = await notebook.invoke( + { + mode: 'write', + insertLine: 'Line 1', + newStr: 'After Line 1', + }, + context + ) + expect(result).toBe("Inserted text at line 2 in notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('Line 1\nAfter Line 1\nLine 2\nLine 3') + }) + + it('inserts after partial text match', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3' }) + const result = await notebook.invoke( + { + mode: 'write', + insertLine: '2', + newStr: 'After match', + }, + context + ) + expect(result).toBe("Inserted text at line 3 in notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('Line 1\nLine 2\nAfter match\nLine 3') + }) + + it('throws error if search text not found', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3' }) + await expect( + notebook.invoke( + { + mode: 'write', + insertLine: 'Nonexistent', + newStr: 'New line', + }, + context + ) + ).rejects.toThrow("Text 'Nonexistent' not found in notebook 'default'") + }) + + it('throws error for line number out of range', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Line 1\nLine 2\nLine 3' }) + await expect( + notebook.invoke( + { + mode: 'write', + insertLine: 100, + newStr: 'New line', + }, + context + ) + ).rejects.toThrow('Line number out of range') + }) + + it('inserts into custom notebook', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { notes: 'First\nSecond' }) + const result = await notebook.invoke( + { + mode: 'write', + name: 'notes', + insertLine: 1, + newStr: 'Middle', + }, + context + ) + expect(result).toBe("Inserted text at line 2 in notebook 'notes'") + const notebooks = state.get('notebooks') + expect(notebooks!.notes).toBe('First\nMiddle\nSecond') + }) + }) + + describe('clear operation', () => { + it('clears default notebook', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Some content' }) + const result = await notebook.invoke({ mode: 'clear' }, context) + expect(result).toBe("Cleared notebook 'default'") + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('') + }) + + it('clears custom notebook', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { notes: 'More content' }) + const result = await notebook.invoke({ mode: 'clear', name: 'notes' }, context) + expect(result).toBe("Cleared notebook 'notes'") + const notebooks = state.get('notebooks') + expect(notebooks!.notes).toBe('') + }) + + it('throws error for non-existent notebook', async () => { + const { context } = createFreshContext() + await expect(notebook.invoke({ mode: 'clear', name: 'missing' }, context)).rejects.toThrow( + "Notebook 'missing' not found" + ) + }) + + it('clearing does not affect other notebooks', async () => { + const { state, context } = createFreshContext() + state.set('notebooks', { default: 'Some content', notes: 'More content' }) + await notebook.invoke({ mode: 'clear', name: 'notes' }, context) + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('Some content') + }) + }) + + describe('state persistence', () => { + it('persists notebooks across operations', async () => { + const { state, context } = createFreshContext() + // Create notebook + await notebook.invoke({ mode: 'create', name: 'notes', newStr: 'Initial' }, context) + let notebooks = state.get('notebooks') + expect(notebooks!.notes).toBe('Initial') + + // Write to notebook - use oldStr/newStr instead of insertLine for appending + await notebook.invoke({ mode: 'write', name: 'notes', oldStr: 'Initial', newStr: 'Initial\nAdded' }, context) + notebooks = state.get('notebooks') + expect(notebooks!.notes).toBe('Initial\nAdded') + + // Read notebook + const content = await notebook.invoke({ mode: 'read', name: 'notes' }, context) + expect(content).toBe('Initial\nAdded') + + // Verify state is still intact + notebooks = state.get('notebooks') + expect(notebooks!.notes).toBe('Initial\nAdded') + }) + + it('initializes default notebook if state is empty', async () => { + const { state, context } = createFreshContext() + const result = await notebook.invoke({ mode: 'list' }, context) + expect(result).toContain('default: Empty') + const notebooks = state.get('notebooks') + expect(notebooks!.default).toBe('') + }) + }) + + describe('validation errors', () => { + it('requires context', async () => { + await expect(notebook.invoke({ mode: 'list' })).rejects.toThrow('Tool context is required') + }) + + it('rejects write without newStr for replacement', async () => { + const { context } = createFreshContext() + await expect( + notebook.invoke( + { + mode: 'write', + oldStr: 'Old', + // Missing newStr + } as any, + context + ) + ).rejects.toThrow() + }) + + it('rejects write without newStr for insertion', async () => { + const { context } = createFreshContext() + await expect( + notebook.invoke( + { + mode: 'write', + insertLine: 1, + // Missing newStr + } as any, + context + ) + ).rejects.toThrow() + }) + + it('rejects write without valid operation parameters', async () => { + const { context } = createFreshContext() + await expect( + notebook.invoke( + { + mode: 'write', + // Missing both replacement and insertion params + } as any, + context + ) + ).rejects.toThrow() + }) + }) +}) diff --git a/strands-ts/src/vended-tools/notebook/index.ts b/strands-ts/src/vended-tools/notebook/index.ts new file mode 100644 index 0000000000..267eed2cce --- /dev/null +++ b/strands-ts/src/vended-tools/notebook/index.ts @@ -0,0 +1,6 @@ +/** + * Notebook tool for managing text notebooks within agent invocations. + */ + +export { notebook } from './notebook.js' +export type { NotebookState, NotebookInput } from './types.js' diff --git a/strands-ts/src/vended-tools/notebook/notebook.ts b/strands-ts/src/vended-tools/notebook/notebook.ts new file mode 100644 index 0000000000..a699f97818 --- /dev/null +++ b/strands-ts/src/vended-tools/notebook/notebook.ts @@ -0,0 +1,262 @@ +import { tool } from '../../index.js' +import { z } from 'zod' +import type { NotebookState } from './types.js' + +/** + * Zod schema for notebook input validation. + */ +const notebookInputSchema = z + .object({ + mode: z + .enum(['create', 'list', 'read', 'write', 'clear']) + .describe('The operation to perform: `create`, `list`, `read`, `write`, `clear`.'), + name: z.string().optional().describe('Name of the notebook to operate on. Defaults to "default".'), + newStr: z.string().optional().describe('New string for replacement or insertion operations.'), + readRange: z + .array(z.number()) + .optional() + .describe('Optional parameter of `view` command. Line range to show [start, end]. Supports negative indices.'), + oldStr: z.string().optional().describe('String to replace in write mode when doing text replacement.'), + insertLine: z + .union([z.string(), z.number()]) + .optional() + .describe( + 'Line number (int) or search text (str) for insertion point in write mode.\nSupports negative indices.' + ), + }) + .refine( + (data) => { + // Validate write mode requirements + if (data.mode === 'write') { + const hasReplacement = data.oldStr !== undefined && data.newStr !== undefined + const hasInsertion = data.insertLine !== undefined && data.newStr !== undefined + return hasReplacement || hasInsertion + } + return true + }, + { + message: + 'Write operation requires either (oldStr + newStr) for replacement or (insertLine + newStr) for insertion', + } + ) + +/** + * Notebook tool for managing persistent text notebooks. + * + * Notebooks are stored in agent state under the 'notebooks' key and persist within an agent session. + * Supports create, list, read, write (replace/insert), and clear operations. + * + * @example + * ```typescript + * // With agent + * const agent = new Agent({ tools: [notebook] }) + * await agent.invoke('Create a notebook called "notes"') + * await agent.invoke('Add "- Task 1" to notes') + * + * // Direct usage + * await notebook.invoke( + * { mode: 'create', name: 'notes', newStr: '# Notes' }, + * { agent: agent, toolUse: { name: 'notebook', toolUseId: 'test', input: {} } } + * ) + * ``` + */ +export const notebook = tool({ + name: 'notebook', + description: + 'Manages text notebooks for note-taking and documentation. Supports create, list, read, write (replace or insert), and clear operations. Notebooks persist within the agent invocation.', + inputSchema: notebookInputSchema, + callback: (input, context) => { + if (!context) { + throw new Error('Tool context is required for notebook operations') + } + + // Get notebooks from state, or initialize if not present + let notebooks = context.agent.appState.get('notebooks') + + if (!notebooks) { + notebooks = {} + } + + // Ensure default notebook exists + if (Object.keys(notebooks).length === 0) { + notebooks.default = '' + } + + let result: string + + switch (input.mode) { + case 'create': + result = handleCreate(notebooks, input.name ?? 'default', input.newStr) + break + + case 'list': + result = handleList(notebooks) + break + + case 'read': + result = handleRead(notebooks, input.name ?? 'default', input.readRange) + break + + case 'write': + result = handleWrite(notebooks, input.name ?? 'default', input.oldStr, input.newStr, input.insertLine) + break + + case 'clear': + result = handleClear(notebooks, input.name ?? 'default') + break + + default: + throw new Error(`Unknown mode: ${input.mode}`) + } + + // Persist notebooks back to state + context.agent.appState.set('notebooks', notebooks) + + return result + }, +}) + +/** + * Handles create operation. + */ +function handleCreate(notebooks: Record, name: string, newStr?: string): string { + notebooks[name] = newStr ?? '' + const message = `Created notebook '${name}'${newStr ? ' with specified content' : ' (empty)'}` + return message +} + +/** + * Handles list operation. + */ +function handleList(notebooks: Record): string { + const notebookNames = Object.keys(notebooks) + const details = notebookNames + .map((name) => { + const lineCount = notebooks[name] ? notebooks[name].split('\n').length : 0 + const status = lineCount === 0 ? 'Empty' : `${lineCount} lines` + return `- ${name}: ${status}` + }) + .join('\n') + + return `Available notebooks:\n${details}` +} + +/** + * Handles read operation. + */ +function handleRead(notebooks: Record, name: string, readRange?: number[]): string { + if (!(name in notebooks)) { + throw new Error(`Notebook '${name}' not found`) + } + + const content = notebooks[name]! + + if (!readRange) { + return content || `Notebook '${name}' is empty` + } + + // Handle line range reading + const lines = content.split('\n') + let start = readRange[0] + let end = readRange[1] + + if (start === undefined || end === undefined) { + throw new Error('`readRange` must be a list of two integers: `[start, end]`') + } + + // Handle negative indices + if (start < 0) { + start = lines.length + start + 1 + } + if (end < 0) { + end = lines.length + end + 1 + } + + const selectedLines: string[] = [] + for (let lineNum = start; lineNum <= end; lineNum++) { + if (lineNum >= 1 && lineNum <= lines.length) { + selectedLines.push(`${lineNum}: ${lines[lineNum - 1]}`) + } + } + + return selectedLines.length > 0 ? selectedLines.join('\n') : 'No valid lines found in range' +} + +/** + * Handles write operation (both string replacement and line insertion). + */ +function handleWrite( + notebooks: Record, + name: string, + oldStr?: string, + newStr?: string, + insertLine?: string | number +): string { + if (!(name in notebooks)) { + throw new Error(`Notebook '${name}' not found`) + } + + // String replacement mode + if (oldStr !== undefined && newStr !== undefined) { + if (!notebooks[name]!.includes(oldStr)) { + throw new Error(`String '${oldStr}' not found in notebook '${name}'`) + } + + notebooks[name] = notebooks[name]!.replace(oldStr, () => newStr) + return `Replaced text in notebook '${name}'` + } + + // Line insertion mode + if (insertLine !== undefined && newStr !== undefined) { + const lines = notebooks[name]!.split('\n') + let lineNum: number + + // Handle string search + if (typeof insertLine === 'string') { + lineNum = -1 + for (let i = 0; i < lines.length; i++) { + if (lines[i]!.includes(insertLine)) { + lineNum = i + break + } + } + if (lineNum === -1) { + throw new Error(`Text '${insertLine}' not found in notebook '${name}'`) + } + } else { + // Handle numeric index with negative support + if (insertLine < 0) { + lineNum = lines.length + insertLine + } else { + lineNum = insertLine - 1 + } + } + + // Validate line number range (allow -1 for prepending before first line) + if (lineNum < -1 || lineNum > lines.length) { + throw new Error(`Line number out of range`) + } + + // Insert at the calculated position + lines.splice(lineNum + 1, 0, newStr) + const updatedContent = lines.join('\n') + Object.assign(notebooks, { [name]: updatedContent }) + + return `Inserted text at line ${lineNum + 2} in notebook '${name}'` + } + + throw new Error('Invalid write operation') +} + +/** + * Handles clear operation. + */ +function handleClear(notebooks: Record, name: string): string { + const notebook = notebooks[name] + if (notebook === undefined) { + throw new Error(`Notebook '${name}' not found`) + } + + notebooks[name] = '' + return `Cleared notebook '${name}'` +} diff --git a/strands-ts/src/vended-tools/notebook/types.ts b/strands-ts/src/vended-tools/notebook/types.ts new file mode 100644 index 0000000000..0b6642a834 --- /dev/null +++ b/strands-ts/src/vended-tools/notebook/types.ts @@ -0,0 +1,85 @@ +/** + * State structure for notebook storage. + * Notebooks are stored in agent state under the 'notebooks' key. + */ +export interface NotebookState { + /** + * Map of notebook names to their content. + * Each notebook stores plain text content with newline-separated lines. + */ + notebooks: Record +} + +/** + * Input parameters for create operation. + * - mode: Operation mode, must be 'create' + * - name: Name of the notebook to create + * - newStr: Optional initial content for the notebook + */ +export interface CreateInput { + mode: 'create' + name?: string + newStr?: string +} + +/** + * Input parameters for list operation. + */ +export interface ListInput { + mode: 'list' +} + +/** + * Input parameters for read operation. + * - mode: Operation mode, must be 'read' + * - name: Name of the notebook to read + * - readRange: Optional line range [start, end] to read. Supports negative indices. + */ +export interface ReadInput { + mode: 'read' + name?: string + readRange?: [number, number] +} + +/** + * Input parameters for write operation (string replacement). + * - mode: Operation mode, must be 'write' + * - name: Name of the notebook to write to + * - oldStr: String to find and replace + * - newStr: Replacement string + */ +export interface WriteReplaceInput { + mode: 'write' + name?: string + oldStr: string + newStr: string +} + +/** + * Input parameters for write operation (line insertion). + * - mode: Operation mode, must be 'write' + * - name: Name of the notebook to write to + * - insertLine: Line number (supports negative indices) or search text for insertion point + * - newStr: Text to insert + */ +export interface WriteInsertInput { + mode: 'write' + name?: string + insertLine: string | number + newStr: string +} + +/** + * Input parameters for clear operation. + * - mode: Operation mode, must be 'clear' + * - name: Name of the notebook to clear + */ +export interface ClearInput { + mode: 'clear' + name?: string +} + +/** + * Union type of all valid notebook inputs. + */ +export type NotebookInput = CreateInput | ListInput | ReadInput | WriteReplaceInput | WriteInsertInput | ClearInput diff --git a/strands-ts/test/integ/__fixtures__/_setup-global.ts b/strands-ts/test/integ/__fixtures__/_setup-global.ts new file mode 100644 index 0000000000..256b453044 --- /dev/null +++ b/strands-ts/test/integ/__fixtures__/_setup-global.ts @@ -0,0 +1,224 @@ +/** + * Global setup that runs once before all integration tests and possibly runs in the *parent* process. + * + * _setup-test on the other hand runs in the *child* process. + */ + +import { SecretsManagerClient, GetSecretValueCommand } from '@aws-sdk/client-secrets-manager' +import { fromNodeProviderChain } from '@aws-sdk/credential-providers' +import express from 'express' +import type { TestProject } from 'vitest/node' +import type { ProvidedContext } from 'vitest' + +import { Agent } from '../../../src/agent/agent.js' +import { A2AExpressServer } from '../../../src/a2a/express-server.js' +import { BedrockModel } from '../../../src/models/bedrock.js' + +/** + * Load API keys as environment variables from AWS Secrets Manager + */ +async function loadApiKeysFromSecretsManager(): Promise { + const client = new SecretsManagerClient({ + region: process.env.AWS_REGION || 'us-east-1', + }) + + try { + const secretName = 'model-provider-api-key' + const command = new GetSecretValueCommand({ + SecretId: secretName, + }) + const response = await client.send(command) + + if (response.SecretString) { + const secret = JSON.parse(response.SecretString) + // Only add API keys for currently supported providers + const supportedProviders = ['openai', 'anthropic', 'gemini'] + Object.entries(secret).forEach(([key, value]) => { + if (supportedProviders.includes(key.toLowerCase())) { + process.env[`${key.toUpperCase()}_API_KEY`] = String(value) + } + }) + } + } catch (e) { + console.warn('Error retrieving secret', e) + } + + /* + * Validate that required environment variables are set when running in GitHub Actions. + * This prevents tests from being unintentionally skipped due to missing credentials. + */ + if (process.env.GITHUB_ACTIONS !== 'true') { + console.warn('Tests running outside GitHub Actions, skipping required provider validation') + return + } + + const requiredProviders: Set = new Set(['OPENAI_API_KEY', 'ANTHROPIC_API_KEY']) + + for (const provider of requiredProviders) { + if (!process.env[provider]) { + throw new Error(`Missing required environment variables for ${provider}`) + } + } +} + +/** + * Perform shared setup for the integration tests. + */ +export async function setup(project: TestProject): Promise<() => void> { + console.log('Global setup: Loading API keys from Secrets Manager...') + await loadApiKeysFromSecretsManager() + console.log('Global setup: API keys loaded into environment') + + const isCI = !!globalThis.process.env.CI + + project.provide('isBrowser', project.isBrowserEnabled()) + project.provide('isCI', isCI) + project.provide('provider-openai', await getOpenAITestContext(isCI)) + project.provide('provider-bedrock', await getBedrockTestContext(isCI)) + project.provide('provider-anthropic', await getAnthropicTestContext(isCI)) + project.provide('provider-gemini', await getGeminiTestContext(isCI)) + + const a2aContext = await getA2AServerContext(project) + project.provide('a2a-server', { shouldSkip: a2aContext.shouldSkip, url: a2aContext.url }) + + return () => { + a2aContext.abort?.() + } +} + +async function getOpenAITestContext(isCI: boolean): Promise { + const apiKey = process.env.OPENAI_API_KEY + const shouldSkip = !apiKey + + if (shouldSkip) { + console.log('⏭️ OpenAI API key not available - integration tests will be skipped') + if (isCI) { + throw new Error('CI/CD should be running all tests') + } + } else { + console.log('⏭️ OpenAI API key available - integration tests will run') + } + + return { + apiKey: apiKey, + shouldSkip: shouldSkip, + } +} + +async function getAnthropicTestContext(isCI: boolean): Promise { + const apiKey = process.env.ANTHROPIC_API_KEY + const shouldSkip = !apiKey + + if (shouldSkip) { + console.log('⏭️ Anthropic API key not available - integration tests will be skipped') + if (isCI) { + throw new Error('CI/CD should be running all tests') + } + } else { + console.log('⏭️ Anthropic API key available - integration tests will run') + } + + return { + apiKey: apiKey, + shouldSkip: shouldSkip, + } +} + +async function getBedrockTestContext(isCI: boolean): Promise { + try { + const credentialProvider = fromNodeProviderChain() + const credentials = await credentialProvider() + console.log('⏭️ Bedrock credentials available - integration tests will run') + return { + shouldSkip: false, + credentials: credentials, + } + } catch { + console.log('⏭️ Bedrock credentials not available - integration tests will be skipped') + if (isCI) { + throw new Error('CI/CD should be running all tests') + } + return { + shouldSkip: true, + credentials: undefined, + } + } +} + +async function getGeminiTestContext(_isCI: boolean): Promise { + const apiKey = process.env.GEMINI_API_KEY + const shouldSkip = !apiKey + + if (shouldSkip) { + console.log('⏭️ Gemini API key not available - integration tests will be skipped') + // Note: Gemini is not required in CI for now, so we don't throw an error + } else { + console.log('⏭️ Gemini API key available - integration tests will run') + } + + return { + apiKey: apiKey, + shouldSkip: shouldSkip, + } +} + +async function getA2AServerContext( + project: TestProject +): Promise void }> { + const { testFiles } = await project.globTestFiles() + const hasA2ATests = testFiles.some((f) => f.includes('/a2a/')) + + if (!hasA2ATests) { + return { shouldSkip: true, url: undefined } + } + + let credentials + try { + const credentialProvider = fromNodeProviderChain() + credentials = await credentialProvider() + } catch { + console.log('⏭️ A2A server not available (no Bedrock credentials) - A2A integration tests will be skipped') + return { shouldSkip: true, url: undefined } + } + + const model = new BedrockModel({ clientConfig: { credentials } }) + const agent = new Agent({ + model, + printer: false, + systemPrompt: 'You are a helpful assistant. Always respond in a single short sentence.', + }) + + const a2aServer = new A2AExpressServer({ + agent, + name: 'Test A2A Agent', + description: 'Integration test agent', + }) + + // Use createMiddleware() with CORS headers so browser integ tests can reach the server. + // Browser tests run on a different port (Vitest dev server), making this a cross-origin request. + const app = express() + app.use((_req, res, next) => { + res.setHeader('Access-Control-Allow-Origin', '*') + res.setHeader('Access-Control-Allow-Methods', '*') + res.setHeader('Access-Control-Allow-Headers', '*') + next() + }) + app.use(a2aServer.createMiddleware()) + + return new Promise((resolve, reject) => { + const server = app.listen(0, '127.0.0.1', () => { + const addr = server.address() as { port: number } + const url = `http://127.0.0.1:${addr.port}` + // Update the agent card URL to reflect the actual bound port. + // createMiddleware() doesn't do this automatically (unlike serve()). + a2aServer.agentCard.url = url + console.log(`⏭️ A2A server started on ${url}`) + resolve({ + shouldSkip: false, + url, + abort: () => server.close(), + }) + }) + server.on('error', reject) + }) +} diff --git a/strands-ts/test/integ/__fixtures__/_setup-test.ts b/strands-ts/test/integ/__fixtures__/_setup-test.ts new file mode 100644 index 0000000000..bbb57deb0b --- /dev/null +++ b/strands-ts/test/integ/__fixtures__/_setup-test.ts @@ -0,0 +1,21 @@ +/** + * Test setup that runs once before all integration tests, but in the *child* process. + * + * _setup-global on the other hand runs in the *parent* process. + */ + +import { beforeAll } from 'vitest' +import { configureLogging } from '$/sdk/logging/index.js' +import { isCI } from './test-helpers.js' + +beforeAll(() => { + // When running under CI/CD, preserve all logs including debug + if (isCI()) { + configureLogging({ + debug: (...args: unknown[]) => console.debug(...args), + info: (...args: unknown[]) => console.info(...args), + warn: (...args: unknown[]) => console.warn(...args), + error: (...args: unknown[]) => console.error(...args), + }) + } +}) diff --git a/strands-ts/test/integ/__fixtures__/model-providers.ts b/strands-ts/test/integ/__fixtures__/model-providers.ts new file mode 100644 index 0000000000..7f000ded56 --- /dev/null +++ b/strands-ts/test/integ/__fixtures__/model-providers.ts @@ -0,0 +1,292 @@ +/** + * Contains helpers for creating various model providers that work both in node & the browser + */ + +import { inject } from 'vitest' +import { BedrockModel, type BedrockModelOptions } from '$/sdk/models/bedrock.js' +import { OpenAIModel, type OpenAIModelOptions } from '$/sdk/models/openai/index.js' +import { AnthropicModel, type AnthropicModelOptions } from '$/sdk/models/anthropic.js' +import { GoogleModel, type GoogleModelOptions } from '$/sdk/models/google/model.js' +import { VercelModel, type VercelModelConfig } from '$/sdk/models/vercel.js' +import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock' +import { createOpenAI } from '@ai-sdk/openai' + +/** + * Feature support flags for model providers. + * Used to conditionally run tests based on model capabilities. + * + * TODO: after https://github.com/strands-agents/sdk-python/issues/780 this config should be in src not test + */ +export interface ProviderFeatures { + reasoning: boolean + tools: boolean + toolThinking: boolean + builtInTools: boolean + images: boolean + documents: boolean + video: boolean + citations: boolean +} + +export const bedrock = { + name: 'BedrockModel', + supports: { + reasoning: true, + tools: true, + toolThinking: false, + builtInTools: false, + images: true, + documents: true, + video: true, + citations: true, + } satisfies ProviderFeatures, + models: { + default: {}, + reasoning: { + modelId: 'us.anthropic.claude-sonnet-4-20250514-v1:0', + additionalRequestFields: { thinking: { type: 'enabled', budget_tokens: 1024 } }, + }, + video: { modelId: 'us.amazon.nova-pro-v1:0' }, + }, + get skip() { + return inject('provider-bedrock').shouldSkip + }, + createModel: (options: BedrockModelOptions = {}): BedrockModel => { + const credentials = inject('provider-bedrock')?.credentials + if (!credentials) { + throw new Error('No Bedrock credentials provided') + } + return new BedrockModel({ + ...options, + clientConfig: { ...(options.clientConfig ?? {}), credentials }, + }) + }, +} + +export const openai = { + name: 'OpenAIModel', + supports: { + reasoning: false, + tools: true, + toolThinking: false, + builtInTools: false, + images: true, + documents: true, + video: false, + citations: false, + } satisfies ProviderFeatures, + models: { + default: {}, + reasoning: { modelId: 'o4-mini' }, + video: {}, + }, + get skip() { + return inject('provider-openai').shouldSkip + }, + createModel: (config: Omit = {}): OpenAIModel => { + const apiKey = inject('provider-openai')?.apiKey + if (!apiKey) { + throw new Error('No OpenAI apiKey provided') + } + return new OpenAIModel({ + ...config, + api: 'chat', + apiKey, + clientConfig: { ...(config.clientConfig ?? {}), dangerouslyAllowBrowser: true }, + }) + }, +} + +export const openaiResponses = { + name: "OpenAIModel (api: 'responses')", + supports: { + reasoning: true, + tools: true, + toolThinking: false, + builtInTools: true, + images: true, + documents: true, + video: false, + citations: true, + } satisfies ProviderFeatures, + models: { + default: {}, + reasoning: { modelId: 'o4-mini' }, + video: {}, + }, + get skip() { + return inject('provider-openai').shouldSkip + }, + createModel: ( + config: Omit, 'api' | 'client'> = {} + ): OpenAIModel => { + const apiKey = inject('provider-openai')?.apiKey + if (!apiKey) { + throw new Error('No OpenAI apiKey provided') + } + return new OpenAIModel({ + ...config, + api: 'responses', + apiKey, + clientConfig: { ...(config.clientConfig ?? {}), dangerouslyAllowBrowser: true }, + }) + }, +} + +export const anthropic = { + name: 'AnthropicModel', + supports: { + reasoning: true, + tools: true, + toolThinking: false, + builtInTools: false, + images: true, + documents: true, + video: false, + citations: false, + } satisfies ProviderFeatures, + models: { + default: {}, + reasoning: { + modelId: 'claude-sonnet-4-6', + params: { thinking: { type: 'enabled', budget_tokens: 1024 } }, + }, + video: {}, + }, + get skip() { + return inject('provider-anthropic').shouldSkip + }, + createModel: (config: AnthropicModelOptions = {}): AnthropicModel => { + const apiKey = inject('provider-anthropic')?.apiKey + if (!apiKey) { + throw new Error('No Anthropic apiKey provided') + } + + return new AnthropicModel({ + ...config, + apiKey: apiKey, + clientConfig: { + ...(config.clientConfig ?? {}), + dangerouslyAllowBrowser: true, + }, + }) + }, +} + +export const gemini = { + name: 'GoogleModel', + supports: { + reasoning: true, + tools: true, + toolThinking: true, + builtInTools: true, + images: true, + documents: true, + video: true, + citations: false, + } satisfies ProviderFeatures, + models: { + default: {}, + reasoning: { + modelId: 'gemini-2.5-flash', + params: { thinkingConfig: { thinkingBudget: 1024, includeThoughts: true } }, + }, + builtInTools: { + builtInTools: [{ codeExecution: {} }], + }, + video: {}, + }, + get skip() { + return inject('provider-gemini').shouldSkip + }, + createModel: (config: GoogleModelOptions = {}): GoogleModel => { + const apiKey = inject('provider-gemini').apiKey + if (!apiKey) { + throw new Error('No Gemini apiKey provided') + } + return new GoogleModel({ ...config, apiKey }) + }, +} + +export const vercelBedrock = { + name: 'VercelModel (Bedrock)', + supports: { + reasoning: true, + tools: true, + toolThinking: false, + builtInTools: false, + images: true, + documents: true, + video: false, + citations: false, + } satisfies ProviderFeatures, + models: { + default: {}, + reasoning: { + providerOptions: { + bedrock: { reasoningConfig: { type: 'enabled', budgetTokens: 1024 } }, + }, + }, + video: {}, + }, + get skip() { + return inject('provider-bedrock').shouldSkip + }, + createModel: (config: Partial = {}): VercelModel => { + const credentials = inject('provider-bedrock')?.credentials + if (!credentials) { + throw new Error('No Bedrock credentials provided') + } + const provider = createAmazonBedrock({ + ...(!credentials.expiration && { region: 'us-west-2' }), + credentialProvider: () => Promise.resolve(credentials), + }) + const { providerOptions, ...rest } = config as Partial & { + providerOptions?: Record + } + return new VercelModel({ + provider: provider('us.anthropic.claude-sonnet-4-20250514-v1:0'), + ...rest, + ...(providerOptions && { providerOptions }), + }) + }, +} + +export const vercelOpenAI = { + name: 'VercelModel (OpenAI)', + supports: { + reasoning: false, + tools: true, + toolThinking: false, + builtInTools: false, + images: true, + documents: true, + video: false, + citations: false, + } satisfies ProviderFeatures, + models: { + default: {}, + reasoning: { modelId: 'o1-mini' }, + video: {}, + }, + get skip() { + return inject('provider-openai').shouldSkip + }, + createModel: (config: Partial = {}): VercelModel => { + const apiKey = inject('provider-openai')?.apiKey + if (!apiKey) { + throw new Error('No OpenAI apiKey provided') + } + const provider = createOpenAI({ apiKey }) + const { providerOptions, ...rest } = config as Partial & { + providerOptions?: Record + } + return new VercelModel({ + provider: provider('gpt-4o'), + ...rest, + ...(providerOptions && { providerOptions }), + }) + }, +} + +export const allProviders = [bedrock, openai, anthropic, gemini, vercelBedrock, vercelOpenAI] diff --git a/strands-ts/test/integ/__fixtures__/model-test-helpers.ts b/strands-ts/test/integ/__fixtures__/model-test-helpers.ts new file mode 100644 index 0000000000..69375fd314 --- /dev/null +++ b/strands-ts/test/integ/__fixtures__/model-test-helpers.ts @@ -0,0 +1,20 @@ +import type { ContentBlock, Message } from '$/sdk/types/messages.js' + +/** + * Extracts plain text content from a Message object. + * + * This helper function handles different message formats by: + * - Extracting text from Message objects by filtering for textBlock content blocks + * - Joining multiple text blocks with newlines + * + * @param message - The message to extract text from. Message object with content blocks + * @returns The extracted text content as a string, or empty string if no content is found + */ +export const getMessageText = (message: Message): string => { + if (!message.content) return '' + + return message.content + .filter((block: ContentBlock) => block.type === 'textBlock') + .map((block) => block.text) + .join('\n') +} diff --git a/strands-ts/test/integ/__fixtures__/test-helpers.ts b/strands-ts/test/integ/__fixtures__/test-helpers.ts new file mode 100644 index 0000000000..3cbbc9b8e9 --- /dev/null +++ b/strands-ts/test/integ/__fixtures__/test-helpers.ts @@ -0,0 +1,117 @@ +import { inject } from 'vitest' +import { Agent, ToolResultBlock, tool } from '@strands-agents/sdk' +import type { AgentResult, InterruptResponseContentData, JSONValue, Message } from '@strands-agents/sdk' +import { z } from 'zod' + +/** + * Checks whether we're running tests in the browser. + */ +export const isInBrowser = () => { + return inject('isBrowser') +} + +export function isCI() { + return inject('isCI') +} + +/** + * Helper to load fixture files from Vite URL imports. + * Vite ?url imports return paths like '/test/integ/__resources__/file.png' in test environment. + * + * @param url - The URL from a Vite ?url import + * @returns The file contents as a Uint8Array + */ +export async function loadFixture(url: string): Promise { + if (isInBrowser()) { + const response = await globalThis.fetch(url) + const arrayBuffer = await response.arrayBuffer() + return new Uint8Array(arrayBuffer) + } else { + const { join } = await import('node:path') + const { readFile } = await import('node:fs/promises') + const relativePath = url.startsWith('/') ? url.slice(1) : url + const filePath = join(process.cwd(), relativePath) + return new Uint8Array(await readFile(filePath)) + } +} + +// ================================ +// Agent Message Helpers +// ================================ + +/** + * Checks if any message contains a toolUseBlock with the specified tool name. + */ +export function hasToolUse(messages: Message[], toolName: string): boolean { + return messages.some((msg) => msg.content.some((block) => block.type === 'toolUseBlock' && block.name === toolName)) +} + +/** + * Counts messages containing toolResultBlocks with the specified status. + */ +export function countToolResults(messages: Message[], status: 'success' | 'error'): number { + return messages.filter((msg) => + msg.content.some((block) => block.type === 'toolResultBlock' && block.status === status) + ).length +} + +/** + * Extracts text content from tool result blocks matching the given status. + */ +export function getToolResultText(messages: Message[], status?: 'success' | 'error'): string { + return messages + .filter((m) => m.role === 'user') + .flatMap((m) => + m.content.filter((b): b is ToolResultBlock => b.type === 'toolResultBlock' && (!status || b.status === status)) + ) + .flatMap((tr) => tr.content.filter((b) => b.type === 'textBlock').map((b) => b.text)) + .join(' ') +} + +/** + * Resumes an interrupted agent by responding to all pending interrupts, + * looping until the agent completes or a max iteration limit is reached. + */ +export async function resumeUntilDone( + agent: Agent, + result: AgentResult, + respond: (interrupt: { id: string; name: string; reason?: unknown }) => JSONValue, + maxRounds = 10 +): Promise { + let current = result + for (let i = 0; i < maxRounds && current.stopReason === 'interrupt'; i++) { + const responses: InterruptResponseContentData[] = current.interrupts!.map((interrupt) => ({ + interruptResponse: { + interruptId: interrupt.id, + response: respond(interrupt), + }, + })) + current = await agent.invoke(responses) + } + return current +} + +// ================================ +// Common Tool Fixtures +// ================================ + +export const timeTool = tool({ + name: 'time_tool', + description: 'Returns the current time. Always call this tool when asked about time.', + inputSchema: z.object({}), + callback: async () => '12:00', +}) + +export const weatherTool = tool({ + name: 'weather_tool', + description: 'Returns the current weather. Always call this tool when asked about weather.', + inputSchema: z.object({}), + callback: async () => 'sunny', +}) + +export const echoTool = tool({ + name: 'echo_tool', + description: 'Echoes back the given message. Always call this tool when asked to echo.', + inputSchema: z.object({ message: z.string().describe('The message to echo') }), + callback: async ({ message }) => `Echo: ${message}`, +}) diff --git a/strands-ts/test/integ/__fixtures__/test-mcp-server.ts b/strands-ts/test/integ/__fixtures__/test-mcp-server.ts new file mode 100644 index 0000000000..ef70763aff --- /dev/null +++ b/strands-ts/test/integ/__fixtures__/test-mcp-server.ts @@ -0,0 +1,257 @@ +/** + * Test MCP Server Implementation + * + * Provides a simple MCP server with test tools for integration testing. + * Supports stdio and HTTP transports. + */ + +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js' +import { createServer, type Server as HttpServer } from 'node:http' +import type { AddressInfo } from 'node:net' +import type { IncomingMessage, ServerResponse } from 'node:http' +import * as z from 'zod/v4' +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' + +/** + * Creates a test MCP server with echo, calculator, and error_tool tools using registerTool. + */ +function createTestServer(): McpServer { + const server = new McpServer( + { + name: 'test-mcp-server', + version: '1.0.0', + }, + { + capabilities: { + tools: {}, + }, + } + ) + + // Register echo tool + server.registerTool( + 'echo', + { + title: 'Echo Tool', + description: 'Echoes back the input message', + inputSchema: { + message: z.string(), + }, + outputSchema: { + echo: z.string(), + }, + }, + async ({ message }) => { + const output = { echo: message } + return { + content: [ + { + type: 'text', + text: message, + }, + ], + structuredContent: output, + } + } + ) + + // Register calculator tool + server.registerTool( + 'calculator', + { + title: 'Calculator Tool', + description: 'Performs basic arithmetic operations', + inputSchema: { + operation: z.enum(['add', 'subtract', 'multiply', 'divide']), + a: z.number(), + b: z.number(), + }, + outputSchema: { + result: z.number(), + }, + }, + async ({ operation, a, b }) => { + let result: number + + switch (operation) { + case 'add': + result = a + b + break + case 'subtract': + result = a - b + break + case 'multiply': + result = a * b + break + case 'divide': + if (b === 0) { + throw new Error('Division by zero') + } + result = a / b + break + } + + const output = { result } + return { + content: [ + { + type: 'text', + text: `Result: ${result}`, + }, + ], + structuredContent: output, + } + } + ) + + // Register confirm_action tool (tests elicitation) + server.registerTool( + 'confirm_action', + { + title: 'Confirm Action Tool', + description: 'Asks the user to confirm before proceeding. Use this tool when you need user confirmation.', + inputSchema: { + action: z.string(), + }, + }, + async ({ action }) => { + const result = await server.server.elicitInput({ + message: `Do you want to proceed with: ${action}?`, + requestedSchema: { + type: 'object', + properties: { + confirmed: { type: 'boolean', description: 'Whether the user confirms' }, + }, + }, + }) + + if (result.action === 'accept') { + return { + content: [{ type: 'text', text: `Action "${action}" confirmed by user` }], + } + } + + return { + content: [{ type: 'text', text: `Action "${action}" was ${result.action}d by user` }], + } + } + ) + + // Register error tool + server.registerTool( + 'error_tool', + { + title: 'Error Tool', + description: 'Intentionally throws an error for testing error handling', + inputSchema: { + error_message: z.string().optional(), + }, + outputSchema: { + error: z.string(), + }, + }, + async ({ error_message }) => { + const message = error_message || 'Intentional error' + throw new Error(message) + } + ) + + return server +} + +/** + * Interface for HTTP-based server info + */ +export interface HttpServerInfo { + server: HttpServer + port: number + url: string + close: () => Promise +} + +/** + * Creates and starts a Streamable HTTP MCP server on a random port. + * Uses stateless mode - creates a new transport for each request. + */ +export async function startHTTPServer(): Promise { + const mcpServer = createTestServer() + + const httpServer = createServer(async (req: IncomingMessage, res: ServerResponse) => { + if (req.url === '/mcp' && req.method === 'POST') { + try { + // Read request body + let body = '' + await new Promise((resolve) => { + req.on('data', (chunk) => { + body += chunk.toString() + }) + req.on('end', () => { + resolve() + }) + }) + + const parsedBody = body ? JSON.parse(body) : undefined + + // Create a new transport for each request (stateless mode) + const transport = new StreamableHTTPServerTransport({ + enableJsonResponse: true, + }) + + res.on('close', async () => { + await transport.close() + }) + + await mcpServer.connect(transport as Transport) + await transport.handleRequest(req, res, parsedBody) + } catch (error) { + console.error('Error handling MCP request:', error) + if (!res.headersSent) { + res.writeHead(500, { 'Content-Type': 'application/json' }) + res.end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error', + }, + id: null, + }) + ) + } + } + } else { + res.writeHead(404) + res.end() + } + }) + + return new Promise((resolve) => { + httpServer.listen(0, () => { + const address = httpServer.address() as AddressInfo + const port = address.port + const url = `http://localhost:${port}/mcp` + + resolve({ + server: httpServer, + port, + url, + close: async () => { + return new Promise((resolveClose) => { + httpServer.close(() => { + resolveClose() + }) + }) + }, + }) + }) + }) +} + +// Start the stdio server when this file is run directly +if (import.meta.url === `file://${process.argv[1]}`) { + const server = createTestServer() + const transport = new StdioServerTransport() + await server.connect(transport) +} diff --git a/strands-ts/test/integ/__fixtures__/test-mcp-task-server.ts b/strands-ts/test/integ/__fixtures__/test-mcp-task-server.ts new file mode 100644 index 0000000000..78e7d476ff --- /dev/null +++ b/strands-ts/test/integ/__fixtures__/test-mcp-task-server.ts @@ -0,0 +1,387 @@ +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js' +import { InMemoryTaskStore } from '@modelcontextprotocol/sdk/experimental/tasks/stores/in-memory.js' +import type { TaskStore, CreateTaskOptions } from '@modelcontextprotocol/sdk/experimental/tasks/interfaces.js' +import type { CallToolResult, Task, Request, RequestId, Result } from '@modelcontextprotocol/sdk/types.js' +import { createServer, type Server as HttpServer } from 'node:http' +import type { AddressInfo } from 'node:net' +import type { IncomingMessage, ServerResponse } from 'node:http' +import * as z from 'zod' + +/** Context stored with long_running_task */ +interface LongRunningContext extends Record { + type: 'long_running' + startTime: number + duration: number + message: string +} + +/** Context stored with instant_task */ +interface InstantContext extends Record { + type: 'instant' + value: string +} + +/** Context stored with failing_task */ +interface FailingContext extends Record { + type: 'failing' + startTime: number + errorMessage: string +} + +type TaskContext = LongRunningContext | InstantContext | FailingContext + +/** + * Calculate task status message based on progress for long_running_task + */ +function getProgressMessage(elapsed: number, duration: number): string { + const progress = elapsed / duration + if (progress < 0.33) return 'Step 1: Initializing...' + if (progress < 0.66) return 'Step 2: Processing...' + return 'Step 3: Finalizing...' +} + +/** + * Custom TaskStore that computes task status statelessly on getTask calls. + * + * This works around two issues in the MCP SDK: + * 1. Custom getTask/getTaskResult handlers registered via registerToolTask are bypassed + * by the Protocol class. See: https://github.com/modelcontextprotocol/typescript-sdk/pull/1335 + * 2. InMemoryTaskStore doesn't store the `context` field from CreateTaskOptions + * + * By storing context ourselves and computing status in getTask(), we ensure proper behavior. + */ +class StatelessTaskStore implements TaskStore { + private _delegate: InMemoryTaskStore + private _contexts: Map = new Map() + + constructor() { + this._delegate = new InMemoryTaskStore() + } + + cleanup(): void { + this._delegate.cleanup() + this._contexts.clear() + } + + async createTask( + taskParams: CreateTaskOptions, + requestId: RequestId, + request: Request, + sessionId?: string + ): Promise { + const task = await this._delegate.createTask(taskParams, requestId, request, sessionId) + // Store context separately since InMemoryTaskStore doesn't store it + if (taskParams.context) { + this._contexts.set(task.taskId, taskParams.context as TaskContext) + } + return task + } + + async updateTaskStatus( + taskId: string, + status: Task['status'], + statusMessage?: string, + sessionId?: string + ): Promise { + return this._delegate.updateTaskStatus(taskId, status, statusMessage, sessionId) + } + + async storeTaskResult( + taskId: string, + status: 'completed' | 'failed', + result: Result, + sessionId?: string + ): Promise { + return this._delegate.storeTaskResult(taskId, status, result, sessionId) + } + + async getTaskResult(taskId: string, sessionId?: string): Promise { + // First compute the status (which may complete the task) + await this.getTask(taskId, sessionId) + return this._delegate.getTaskResult(taskId, sessionId) + } + + async listTasks(cursor?: string, sessionId?: string): Promise<{ tasks: Task[]; nextCursor?: string }> { + return this._delegate.listTasks(cursor, sessionId) + } + + /** + * Override getTask to compute status from elapsed time for time-based tasks. + */ + async getTask(taskId: string, sessionId?: string): Promise { + const task = await this._delegate.getTask(taskId, sessionId) + if (!task) return task + + // Get context from our separate store + const ctx = this._contexts.get(taskId) + if (!ctx) return task + + // Handle long_running_task: calculate status from elapsed time + if (ctx.type === 'long_running') { + const elapsed = Date.now() - ctx.startTime + if (elapsed >= ctx.duration) { + // Task is done - mark completed + if (task.status !== 'completed') { + await this._delegate.storeTaskResult(taskId, 'completed', { + content: [{ type: 'text', text: ctx.message }], + }) + } + } else { + // Still working - update status message + await this._delegate.updateTaskStatus(taskId, 'working', getProgressMessage(elapsed, ctx.duration)) + } + return await this._delegate.getTask(taskId, sessionId) + } + + // Handle failing_task: fail after delay + if (ctx.type === 'failing') { + const elapsed = Date.now() - ctx.startTime + const failDelay = 60 // ms before failing + + if (elapsed >= failDelay) { + // Time to fail + if (task.status !== 'failed') { + await this._delegate.storeTaskResult(taskId, 'failed', { + content: [{ type: 'text', text: ctx.errorMessage }], + isError: true, + }) + } + } else { + // Still "working" before failure + await this._delegate.updateTaskStatus(taskId, 'working', 'About to fail...') + } + return await this._delegate.getTask(taskId, sessionId) + } + + // instant_task and others: no special handling needed + return task + } +} + +/** + * Creates a test MCP server with task-enabled tools using the high-level API. + * + * Note: Due to an MCP SDK bug (https://github.com/modelcontextprotocol/typescript-sdk/pull/1335), + * custom getTask/getTaskResult handlers are bypassed. Status calculation is done in + * StatelessTaskStore.getTask() instead. + */ +function createTaskTestServer(taskStore: StatelessTaskStore): McpServer { + const server = new McpServer( + { name: 'test-mcp-task-server', version: '1.0.0' }, + { + capabilities: { + tools: {}, + tasks: { + requests: { + tools: { call: {} }, + }, + }, + }, + taskStore, + } + ) + + // Register long_running_task - stores context with timing info + // Status calculation happens in StatelessTaskStore.getTask() + server.experimental.tasks.registerToolTask( + 'long_running_task', + { + description: 'Simulates a long-running task with progress updates', + inputSchema: { + duration: z.number().optional().describe('Duration in milliseconds (default: 200)'), + message: z.string().optional().describe('Message to include in result'), + }, + }, + { + async createTask({ duration, message }, { taskStore: store }) { + const context: LongRunningContext = { + type: 'long_running', + startTime: Date.now(), + duration: duration ?? 200, + message: message ?? 'Task completed!', + } + const task = await store.createTask({ ttl: 60000, pollInterval: 50, context }) + return { task } + }, + + async getTask(_args, { taskId, taskStore: store }) { + return await store.getTask(taskId) + }, + + async getTaskResult(_args, { taskId, taskStore: store }) { + const result = await store.getTaskResult(taskId) + return result as CallToolResult + }, + } + ) + + // Register instant_task - completes immediately on creation + server.experimental.tasks.registerToolTask( + 'instant_task', + { + description: 'A task that completes immediately', + inputSchema: { + value: z.string().optional().describe('Value to return'), + }, + }, + { + async createTask({ value }, { taskStore: store }) { + const context: InstantContext = { + type: 'instant', + value: value ?? 'instant result', + } + const task = await store.createTask({ ttl: 60000, pollInterval: 50, context }) + // Complete immediately + await store.storeTaskResult(task.taskId, 'completed', { + content: [{ type: 'text', text: context.value }], + }) + return { task } + }, + + async getTask(_args, { taskId, taskStore: store }) { + return await store.getTask(taskId) + }, + + async getTaskResult(_args, { taskId, taskStore: store }) { + const result = await store.getTaskResult(taskId) + return result as CallToolResult + }, + } + ) + + // Register failing_task - stores context with timing info + // Failure logic happens in StatelessTaskStore.getTask() + server.experimental.tasks.registerToolTask( + 'failing_task', + { + description: 'A task that always fails for error handling testing', + inputSchema: { + error_message: z.string().optional().describe('Error message to return'), + }, + }, + { + async createTask({ error_message }, { taskStore: store }) { + const context: FailingContext = { + type: 'failing', + startTime: Date.now(), + errorMessage: error_message ?? 'Task intentionally failed', + } + const task = await store.createTask({ ttl: 60000, pollInterval: 50, context }) + return { task } + }, + + async getTask(_args, { taskId, taskStore: store }) { + return await store.getTask(taskId) + }, + + async getTaskResult(_args, { taskId, taskStore: store }) { + const result = await store.getTaskResult(taskId) + return result as CallToolResult + }, + } + ) + + return server +} + +/** + * Interface for HTTP-based server info + */ +export interface TaskHttpServerInfo { + server: HttpServer + port: number + url: string + close: () => Promise +} + +/** + * Creates and starts a task-enabled Streamable HTTP MCP server on a random port. + * Creates a new McpServer per request while sharing the taskStore. + */ +export async function startTaskHTTPServer(): Promise { + const taskStore = new StatelessTaskStore() + + const httpServer = createServer(async (req: IncomingMessage, res: ServerResponse) => { + if (req.url === '/mcp' && req.method === 'POST') { + try { + // Read request body + let body = '' + await new Promise((resolve) => { + req.on('data', (chunk) => { + body += chunk.toString() + }) + req.on('end', () => { + resolve() + }) + }) + + const parsedBody = body ? JSON.parse(body) : undefined + + // Create a new server and transport for each request + // The taskStore is shared across all requests to persist task state + const mcpServer = createTaskTestServer(taskStore) + const transport = new StreamableHTTPServerTransport({ + enableJsonResponse: true, + }) + + res.on('close', async () => { + await transport.close() + }) + + // @ts-expect-error - MCP SDK doesn't support exactOptionalPropertyTypes + await mcpServer.connect(transport) + await transport.handleRequest(req, res, parsedBody) + } catch (error) { + console.error('Error handling MCP request:', error) + if (!res.headersSent) { + res.writeHead(500, { 'Content-Type': 'application/json' }) + res.end( + JSON.stringify({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error', + }, + id: null, + }) + ) + } + } + } else { + res.writeHead(404) + res.end() + } + }) + + return new Promise((resolve) => { + httpServer.listen(0, () => { + const address = httpServer.address() as AddressInfo + const port = address.port + const url = `http://localhost:${port}/mcp` + + resolve({ + server: httpServer, + port, + url, + close: async () => { + taskStore.cleanup() + return new Promise((resolveClose) => { + httpServer.close(() => { + resolveClose() + }) + }) + }, + }) + }) + }) +} + +// Start the stdio server when this file is run directly +if (import.meta.url === `file://${process.argv[1]}`) { + const taskStore = new StatelessTaskStore() + const server = createTaskTestServer(taskStore) + const transport = new StdioServerTransport() + await server.connect(transport) +} diff --git a/strands-ts/test/integ/__resources__/letter.pdf b/strands-ts/test/integ/__resources__/letter.pdf new file mode 100644 index 0000000000..d8c59f7492 Binary files /dev/null and b/strands-ts/test/integ/__resources__/letter.pdf differ diff --git a/strands-ts/test/integ/__resources__/yellow.mp4 b/strands-ts/test/integ/__resources__/yellow.mp4 new file mode 100644 index 0000000000..0e5c55735d Binary files /dev/null and b/strands-ts/test/integ/__resources__/yellow.mp4 differ diff --git a/strands-ts/test/integ/__resources__/yellow.png b/strands-ts/test/integ/__resources__/yellow.png new file mode 100644 index 0000000000..9caac13bed Binary files /dev/null and b/strands-ts/test/integ/__resources__/yellow.png differ diff --git a/strands-ts/test/integ/a2a/a2a-agent.test.ts b/strands-ts/test/integ/a2a/a2a-agent.test.ts new file mode 100644 index 0000000000..4eafa34ac6 --- /dev/null +++ b/strands-ts/test/integ/a2a/a2a-agent.test.ts @@ -0,0 +1,48 @@ +import { describe, expect, it, inject, beforeAll } from 'vitest' +import { A2AAgent, A2AStreamUpdateEvent } from '$/sdk/a2a/index.js' +import { collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' + +const a2aServer = { + get skip() { + return inject('a2a-server').shouldSkip + }, + get url() { + const url = inject('a2a-server').url + if (!url) throw new Error('A2A server URL not provided') + return url + }, +} + +describe.skipIf(a2aServer.skip)('A2AAgent', () => { + let agent: A2AAgent + + beforeAll(() => { + agent = new A2AAgent({ url: a2aServer.url }) + }) + + describe('invoke', () => { + it('receives a text response and populates agent card metadata', async () => { + const result = await agent.invoke('What is 2+2? Reply with just the number.') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + expect(result.lastMessage.content.length).toBeGreaterThan(0) + expect(result.toString()).toMatch(/4/) + + expect(agent.name).toBe('Test A2A Agent') + expect(agent.description).toBe('Integration test agent') + }) + }) + + describe('stream', () => { + it('yields events and returns final result', async () => { + const { items, result } = await collectGenerator(agent.stream('Say the word test')) + + const streamUpdates = items.filter((e) => e instanceof A2AStreamUpdateEvent) + expect(streamUpdates.length).toBeGreaterThan(0) + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content[0]!.type).toBe('textBlock') + }) + }) +}) diff --git a/strands-ts/test/integ/a2a/express-server.test.node.ts b/strands-ts/test/integ/a2a/express-server.test.node.ts new file mode 100644 index 0000000000..50c90a65aa --- /dev/null +++ b/strands-ts/test/integ/a2a/express-server.test.node.ts @@ -0,0 +1,172 @@ +import { describe, expect, it, afterAll, beforeAll, afterEach } from 'vitest' +import { readFile } from 'node:fs/promises' +import { join } from 'node:path' +import type { Server } from 'node:http' +import type { AddressInfo } from 'node:net' +import type { Task } from '@a2a-js/sdk' +import express from 'express' +import { ClientFactory } from '@a2a-js/sdk/client' +import { Agent } from '@strands-agents/sdk' +import { A2AAgent, A2AStreamUpdateEvent, A2AResultEvent } from '$/sdk/a2a/index.js' +import { A2AExpressServer } from '$/sdk/a2a/express-server.js' +import { TextBlock } from '$/sdk/types/messages.js' +import { encodeBase64 } from '$/sdk/types/media.js' +import { collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { bedrock } from '../__fixtures__/model-providers.js' + +describe.skipIf(bedrock.skip)('A2AExpressServer', () => { + describe('serve', () => { + let a2aServer: A2AExpressServer + let abortController: AbortController + + beforeAll(async () => { + const agent = new Agent({ + model: bedrock.createModel(), + printer: false, + systemPrompt: 'You are a helpful assistant. Always respond in a single short sentence.', + }) + + a2aServer = new A2AExpressServer({ + agent, + name: 'Test A2A Agent', + description: 'Integration test agent', + port: 0, + }) + + abortController = new AbortController() + await a2aServer.serve({ signal: abortController.signal }) + }) + + afterAll(() => { + abortController?.abort() + }) + + it('serves agent card at well-known endpoint', async () => { + const factory = new ClientFactory() + const client = await factory.createFromUrl(`http://127.0.0.1:${a2aServer.port}`) + const card = await client.getAgentCard() + + expect(card.name).toBe('Test A2A Agent') + expect(card.description).toBe('Integration test agent') + expect(card.capabilities?.streaming).toBe(true) + }) + + it('processes an image sent as a file part', async () => { + const imagePath = join(process.cwd(), 'test/integ/__resources__/yellow.png') + const imageBytes = new Uint8Array(await readFile(imagePath)) + + const factory = new ClientFactory() + const rawClient = await factory.createFromUrl(`http://127.0.0.1:${a2aServer.port}`) + + const result = (await rawClient.sendMessage({ + message: { + kind: 'message', + messageId: globalThis.crypto.randomUUID(), + role: 'user', + parts: [ + { + kind: 'file', + file: { bytes: encodeBase64(imageBytes), mimeType: 'image/png' }, + }, + { kind: 'text', text: 'What color is this image? Reply with just the color name.' }, + ], + }, + })) as Task + + expect(result.kind).toBe('task') + expect(result.status.state).toBe('completed') + + const texts = result + .artifacts!.flatMap((a) => a.parts) + .filter((p) => p.kind === 'text') + .map((p) => (p as { kind: 'text'; text: string }).text) + .join('') + + expect(texts.toLowerCase()).toContain('yellow') + }) + }) + + describe('createMiddleware', () => { + const servers: Server[] = [] + + afterEach(() => { + for (const server of servers) { + server.close() + } + servers.length = 0 + }) + + /** + * Starts an A2A server on an OS-assigned port and returns the URL. + */ + async function startServer(agent: Agent): Promise<{ url: string }> { + return new Promise((resolve, reject) => { + const app = express() + const server = app.listen(0, 'localhost', () => { + const { port } = server.address() as AddressInfo + servers.push(server) + + const url = `http://localhost:${port}` + const a2aServer = new A2AExpressServer({ + agent, + name: 'Test Agent', + description: 'Agent for A2A integration tests', + httpUrl: url, + }) + app.use(a2aServer.createMiddleware()) + + resolve({ url }) + }) + server.on('error', reject) + }) + } + + it('invoke returns AgentResult with response text', async () => { + const agent = new Agent({ + model: bedrock.createModel({ maxTokens: 256 }), + printer: false, + systemPrompt: 'Respond with exactly one word: "pong".', + }) + + const { url } = await startServer(agent) + const remoteAgent = new A2AAgent({ url }) + + const result = await remoteAgent.invoke('ping') + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + expect(result.lastMessage.content).toHaveLength(1) + expect(result.lastMessage.content[0]).toBeInstanceOf(TextBlock) + expect((result.lastMessage.content[0] as TextBlock).text.toLowerCase()).toContain('pong') + }) + + it('stream yields A2AStreamUpdateEvents and A2AResultEvent', async () => { + const agent = new Agent({ + model: bedrock.createModel({ maxTokens: 256 }), + printer: false, + systemPrompt: 'Respond with exactly one word: "pong".', + }) + + const { url } = await startServer(agent) + const remoteAgent = new A2AAgent({ url }) + + const { items, result } = await collectGenerator(remoteAgent.stream('ping')) + + const streamUpdates = items.filter((e) => e instanceof A2AStreamUpdateEvent) + const resultEvents = items.filter((e) => e instanceof A2AResultEvent) + + expect(streamUpdates.length).toBeGreaterThan(0) + expect(resultEvents).toHaveLength(1) + + for (const update of streamUpdates) { + expect(['message', 'task', 'status-update', 'artifact-update']).toContain( + (update as A2AStreamUpdateEvent).event.kind + ) + } + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + expect((result.lastMessage.content[0] as TextBlock).text.toLowerCase()).toContain('pong') + }) + }) +}) diff --git a/strands-ts/test/integ/agent-as-tool.test.ts b/strands-ts/test/integ/agent-as-tool.test.ts new file mode 100644 index 0000000000..b58df692fc --- /dev/null +++ b/strands-ts/test/integ/agent-as-tool.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, it } from 'vitest' +import { Agent, tool } from '@strands-agents/sdk' +import { z } from 'zod' +import { bedrock } from './__fixtures__/model-providers.js' + +const getTigerHeight = tool({ + name: 'get_tiger_height', + description: 'Returns the height of a tiger in centimeters', + inputSchema: z.object({}), + callback: async () => 100, +}) + +describe.skipIf(bedrock.skip)('AgentAsTool (integration)', () => { + it('parent agent invokes a sub-agent tool that uses a standard tool and gets a result', async () => { + const innerAgent = new Agent({ + model: bedrock.createModel({ maxTokens: 500 }), + name: 'tiger_expert', + description: 'An agent knowledgeable about tigers', + tools: [getTigerHeight], + printer: false, + }) + + const outerAgent = new Agent({ + model: bedrock.createModel({ maxTokens: 500 }), + tools: [innerAgent.asTool()], + printer: false, + }) + + const result = await outerAgent.invoke('Ask the tiger_expert about the height of tigers.') + + expect(result.stopReason).toBe('endTurn') + expect(result.metrics?.toolMetrics['tiger_expert']?.successCount).toBeGreaterThanOrEqual(1) + }) +}) diff --git a/strands-ts/test/integ/agent.cancel.test.ts b/strands-ts/test/integ/agent.cancel.test.ts new file mode 100644 index 0000000000..2c0168fdfa --- /dev/null +++ b/strands-ts/test/integ/agent.cancel.test.ts @@ -0,0 +1,125 @@ +import { describe, expect, it } from 'vitest' +import { Agent, tool } from '@strands-agents/sdk' +import { z } from 'zod' + +import { allProviders } from './__fixtures__/model-providers.js' + +describe.each(allProviders)('Cancellation with $name', ({ name, skip, createModel, supports }) => { + describe.skipIf(skip)(`${name} Cancellation`, () => { + it('cancels during model streaming', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Write a very long story about a dragon.', + }) + + let streamEventsReceived = 0 + for await (const event of agent.stream('Begin')) { + if (event.type === 'modelStreamUpdateEvent') { + streamEventsReceived++ + if (streamEventsReceived === 1) { + agent.cancel() + } + } + } + + expect(streamEventsReceived).toBeGreaterThanOrEqual(1) + + // Messages should be in a valid, reinvokable state + const lastMessage = agent.messages[agent.messages.length - 1]! + expect(lastMessage.role).toBe('assistant') + }) + + it.skipIf(!supports.tools)('cancels before tool execution', async () => { + let toolExecuted = false + const trackedCalculator = tool({ + name: 'calculator', + description: 'Performs basic arithmetic operations. Always use this tool for math.', + inputSchema: z.object({ + operation: z.enum(['add', 'subtract', 'multiply', 'divide']), + a: z.number(), + b: z.number(), + }), + callback: async ({ operation, a, b }) => { + toolExecuted = true + const ops = { add: a + b, subtract: a - b, multiply: a * b, divide: a / b } + return `Result: ${ops[operation]}` + }, + }) + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the calculator tool for all math. Do not attempt mental math.', + tools: [trackedCalculator], + }) + + for await (const event of agent.stream('What is 999 * 111?')) { + if (event.type === 'modelMessageEvent' && event.stopReason === 'toolUse') { + agent.cancel() + } + } + + expect(toolExecuted).toBe(false) + + // Messages should include the assistant's tool use and cancellation tool results + const toolUseMsg = agent.messages.find((m) => m.content.some((b) => b.type === 'toolUseBlock')) + expect(toolUseMsg).toBeDefined() + const toolResultMsg = agent.messages.find((m) => + m.content.some((b) => b.type === 'toolResultBlock' && b.status === 'error') + ) + expect(toolResultMsg).toBeDefined() + }) + + it('cancels from a timer using agent.cancel()', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Write an extremely long and detailed story. Never stop writing.', + }) + + // Cancel after a short delay — simulates a timeout or external trigger + globalThis.setTimeout(() => agent.cancel(), 500) + + const result = await agent.invoke('Write a 10000 word story') + + expect(result.stopReason).toBe('cancelled') + }) + + it('cancels via AbortSignal.timeout()', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Write an extremely long and detailed story. Never stop writing.', + }) + + const result = await agent.invoke('Write a 10000 word story', { + cancelSignal: AbortSignal.timeout(500), + }) + + expect(result.stopReason).toBe('cancelled') + }) + + it('allows reuse after cancellation', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + }) + + // First invocation: cancel during streaming + for await (const event of agent.stream('Write a very long story')) { + if (event.type === 'modelStreamUpdateEvent') { + agent.cancel() + break + } + } + + const lastMessage = agent.messages[agent.messages.length - 1]! + expect(lastMessage.role).toBe('assistant') + + // Second invocation: should succeed normally + const result = await agent.invoke('Say the word "pineapple"') + expect(result.stopReason).toBe('endTurn') + }) + }) +}) diff --git a/strands-ts/test/integ/agent.test.ts b/strands-ts/test/integ/agent.test.ts new file mode 100644 index 0000000000..076dc402ea --- /dev/null +++ b/strands-ts/test/integ/agent.test.ts @@ -0,0 +1,611 @@ +import { describe, expect, it } from 'vitest' +import { + Agent, + CitationsBlock, + DocumentBlock, + ImageBlock, + Message, + TextBlock, + ToolUseBlock, + VideoBlock, + tool, +} from '@strands-agents/sdk' +import { notebook } from '@strands-agents/sdk/vended-tools/notebook' +import { httpRequest } from '@strands-agents/sdk/vended-tools/http-request' +import { z } from 'zod' + +import { collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { loadFixture } from './__fixtures__/test-helpers.js' +// Import fixtures using Vite's ?url suffix +import yellowMp4Url from './__resources__/yellow.mp4?url' +import yellowPngUrl from './__resources__/yellow.png?url' +import letterPdfUrl from './__resources__/letter.pdf?url' +import { allProviders } from './__fixtures__/model-providers.js' + +// Calculator tool using Zod schema +const calculatorTool = tool({ + name: 'calculator', + description: 'Performs basic arithmetic operations', + inputSchema: z.object({ + operation: z.enum(['add', 'subtract', 'multiply', 'divide']), + a: z.number(), + b: z.number(), + }), + callback: async ({ operation, a, b }) => { + const ops = { + add: a + b, + subtract: a - b, + multiply: a * b, + divide: a / b, + } + return `Result: ${ops[operation]}` + }, +}) + +// Calculator tool using JSON schema +const jsonCalculatorTool = tool({ + name: 'calculator', + description: 'Performs basic arithmetic operations', + inputSchema: { + type: 'object', + properties: { + operation: { type: 'string', enum: ['add', 'subtract', 'multiply', 'divide'] }, + a: { type: 'number' }, + b: { type: 'number' }, + }, + required: ['operation', 'a', 'b'], + }, + callback: async (input) => { + const { operation, a, b } = input as { operation: 'add' | 'subtract' | 'multiply' | 'divide'; a: number; b: number } + const ops = { + add: a + b, + subtract: a - b, + multiply: a * b, + divide: a / b, + } + return `Result: ${ops[operation]}` + }, +}) + +describe.each(allProviders)('Agent with $name', ({ name, skip, createModel, models, supports }) => { + describe.skipIf(skip)(`${name} Integration Tests`, () => { + describe('Basic Functionality', () => { + it.skipIf(!supports.tools)('handles invocation, streaming, system prompts, and tool use', async () => { + // Test basic invocation with system prompt and tool + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the calculator tool to solve math problems. Respond with only the numeric result.', + tools: [calculatorTool], + }) + + // Test streaming with event collection + const { items, result } = await collectGenerator(agent.stream('What is 123 * 456?')) + + // Verify high-level agent events are yielded + expect(items.some((item) => item.type === 'beforeInvocationEvent')).toBe(true) + + // Verify result structure and stop reason + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + expect(result.lastMessage.content.length).toBeGreaterThan(0) + + // Verify tool was used by checking message history + const toolUseMessage = agent.messages.find((msg) => msg.content.some((block) => block.type === 'toolUseBlock')) + expect(toolUseMessage).toBeDefined() + + // Verify final response contains the result (123 * 456 = 56088) + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text).toMatch(/56088/) + + // Validate multi-turn works after tool use + await collectGenerator(agent.stream('What was the result?')) + }) + + it.skipIf(!supports.tools)('handles tool use with JSON schema tool', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the calculator tool to solve math problems. Respond with only the numeric result.', + tools: [jsonCalculatorTool], + }) + + const result = await agent.invoke('What is 25 * 48?') + + expect(result.stopReason).toBe('endTurn') + + const toolUseMessage = agent.messages.find((msg) => msg.content.some((block) => block.type === 'toolUseBlock')) + expect(toolUseMessage).toBeDefined() + + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text).toMatch(/1200/) + }) + + it('yields metadata events through the agent stream', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Respond with a brief greeting.', + }) + + // Test streaming with event collection + const { items, result } = await collectGenerator(agent.stream('Say hello')) + + // Verify metadata event is yielded through the agent (wrapped in ModelStreamUpdateEvent) + const updateEvent = items.find( + (item) => item.type === 'modelStreamUpdateEvent' && item.event.type === 'modelMetadataEvent' + ) + expect(updateEvent).toBeDefined() + if (updateEvent?.type !== 'modelStreamUpdateEvent' || updateEvent.event.type !== 'modelMetadataEvent') { + throw new Error('Expected modelStreamUpdateEvent wrapping modelMetadataEvent') + } + const metadataEvent = updateEvent.event + expect(metadataEvent.usage).toBeDefined() + expect(metadataEvent.usage?.inputTokens).toBeGreaterThan(0) + expect(metadataEvent.usage?.outputTokens).toBeGreaterThan(0) + + // Bedrock includes latencyMs in metrics, OpenAI does not + if (name === 'BedrockModel') { + expect(metadataEvent.metrics?.latencyMs).toBeGreaterThan(0) + } + + // Verify result structure + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + expect(result.lastMessage.content.length).toBeGreaterThan(0) + }) + }) + + describe('Multi-turn Conversations', () => { + it('maintains message history and conversation context', async () => { + const agent = new Agent({ model: createModel(), printer: false }) + + // First turn + await agent.invoke('My name is Alice') + expect(agent.messages).toHaveLength(2) // user + assistant + + // Second turn + await agent.invoke('What is my name?') + expect(agent.messages).toHaveLength(4) // 2 user + 2 assistant + + // Verify message ordering + expect(agent.messages[0]?.role).toBe('user') + expect(agent.messages[1]?.role).toBe('assistant') + expect(agent.messages[2]?.role).toBe('user') + expect(agent.messages[3]?.role).toBe('assistant') + + // Verify conversation context is preserved + const lastMessage = agent.messages[agent.messages.length - 1] + const textContent = lastMessage?.content.find((block) => block.type === 'textBlock') + expect(textContent?.text).toMatch(/Alice/i) + }) + }) + + describe.skipIf(!supports.images || !supports.documents)('Media Blocks', () => { + it('handles multiple media blocks in single request', async () => { + // Create document block + const docBlock = new DocumentBlock({ + name: 'test-document', + format: 'txt', + source: { text: 'The document contains the word ZEBRA.' }, + }) + + // Create image block + const imageBytes = await loadFixture(yellowPngUrl) + const imageBlock = new ImageBlock({ + format: 'png', + source: { bytes: imageBytes }, + }) + + // Initialize agent with messages array containing Message instance + // Note: Bedrock requires a text block when using documents + const agent = new Agent({ + model: createModel(), + messages: [ + new Message({ + role: 'user', + content: [ + docBlock, + imageBlock, + new TextBlock( + 'I shared a document and an image. What animal is in the document and what color is the image? Answer briefly.' + ), + ], + }), + ], + printer: false, + }) + + const result = await agent.invoke([]) + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + + // Response should reference both the document content and image color + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text).toMatch(/zebra/i) + }) + + it('processes PDF document input correctly', async () => { + const pdfBytes = await loadFixture(letterPdfUrl) + + const agent = new Agent({ + model: createModel(), + messages: [ + new Message({ + role: 'user', + content: [ + new DocumentBlock({ + name: 'letter', + format: 'pdf', + source: { bytes: pdfBytes }, + }), + new TextBlock('Summarize this document briefly.'), + ], + }), + ], + printer: false, + }) + + const result = await agent.invoke([]) + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text.length).toBeGreaterThan(10) + }) + }) + + it.skipIf(!supports.documents)('handles document input', async () => { + const docBlock = new DocumentBlock({ + name: 'test-document', + format: 'txt', + source: { text: 'The secret code word is ELEPHANT.' }, + }) + + const agent = new Agent({ + model: createModel(), + printer: false, + }) + + const result = await agent.invoke([ + new TextBlock('What is the secret code word in the document? Answer in one word.'), + docBlock, + ]) + + expect(result.stopReason).toBe('endTurn') + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text).toMatch(/elephant/i) + }) + + it.skipIf(!supports.video)('handles video input', async () => { + const videoBytes = await loadFixture(yellowMp4Url) + const videoBlock = new VideoBlock({ + format: 'mp4', + source: { bytes: videoBytes }, + }) + + const agent = new Agent({ + model: createModel(models.video), + printer: false, + }) + + const result = await agent.invoke([ + new TextBlock('What color is shown in this video? Answer in one word.'), + videoBlock, + ]) + + expect(result.stopReason).toBe('endTurn') + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + }) + + describe.skipIf(!supports.citations)('Citations', () => { + const documentText = [ + 'France is a country in Western Europe. Its capital is Paris, which is known as the City of Light.', + 'Paris has a population of approximately 2.1 million people in the city proper.', + 'The Eiffel Tower, built in 1889, is the most visited paid monument in the world.', + 'France is the most visited country in the world, with over 89 million tourists annually.', + 'The French Revolution of 1789 was a pivotal event in world history.', + ].join(' ') + + const textDocBlock = new DocumentBlock({ + name: 'test-document', + format: 'txt', + source: { content: [{ text: documentText }] }, + citations: { enabled: true }, + }) + + const textDocPrompt = new TextBlock( + 'Using the document, what is the capital of France and what is it known for? Cite specific details.' + ) + + it('returns documentChunk citations from text document', async () => { + const agent = new Agent({ + model: createModel({ stream: false }), + printer: false, + }) + + const result = await agent.invoke([textDocBlock, textDocPrompt]) + + expect(result.stopReason).toBe('endTurn') + + const citationsBlock = result.lastMessage.content.find( + (block): block is CitationsBlock => block.type === 'citationsBlock' + ) + expect(citationsBlock).toBeDefined() + expect(citationsBlock!.citations).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + location: expect.objectContaining({ type: 'documentChunk' }), + source: expect.any(String), + title: expect.any(String), + sourceContent: expect.arrayContaining([expect.objectContaining({ text: expect.any(String) })]), + }), + ]) + ) + expect(citationsBlock!.content).toEqual( + expect.arrayContaining([expect.objectContaining({ text: expect.any(String) })]) + ) + }) + + it('returns documentPage citations from PDF document and preserves them in multi-turn', async () => { + const pdfBytes = await loadFixture(letterPdfUrl) + + const agent = new Agent({ + model: createModel({ stream: false }), + printer: false, + }) + + const result = await agent.invoke([ + new DocumentBlock({ + name: 'letter', + format: 'pdf', + source: { bytes: pdfBytes }, + citations: { enabled: true }, + }), + new TextBlock('Summarize this document briefly.'), + ]) + + expect(result.stopReason).toBe('endTurn') + + const citationsBlock = result.lastMessage.content.find( + (block): block is CitationsBlock => block.type === 'citationsBlock' + ) + expect(citationsBlock).toBeDefined() + expect(citationsBlock!.citations).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + location: expect.objectContaining({ type: 'documentPage' }), + source: expect.any(String), + title: expect.any(String), + sourceContent: expect.arrayContaining([expect.objectContaining({ text: expect.any(String) })]), + }), + ]) + ) + expect(citationsBlock!.content).toEqual( + expect.arrayContaining([expect.objectContaining({ text: expect.any(String) })]) + ) + + // Second turn: verify citations survive in conversation history + const followUp = await agent.invoke('What else can you tell me about this document?') + expect(followUp.stopReason).toBe('endTurn') + expect(followUp.lastMessage.role).toBe('assistant') + expect(followUp.lastMessage.content.length).toBeGreaterThan(0) + }) + + it.each([ + { mode: 'non-streaming', stream: false as const }, + { mode: 'streaming', stream: true as const }, + ])('emits citationsDelta events in $mode mode', async ({ stream }) => { + const agent = new Agent({ + model: createModel({ stream }), + printer: false, + }) + + const { items, result } = await collectGenerator(agent.stream([textDocBlock, textDocPrompt])) + + expect(result.stopReason).toBe('endTurn') + + const citationDeltas = items.filter( + (item) => + item.type === 'modelStreamUpdateEvent' && + item.event.type === 'modelContentBlockDeltaEvent' && + item.event.delta.type === 'citationsDelta' + ) + expect(citationDeltas.length).toBeGreaterThan(0) + + const citationsBlock = result.lastMessage.content.find( + (block): block is CitationsBlock => block.type === 'citationsBlock' + ) + expect(citationsBlock).toBeDefined() + expect(citationsBlock!.citations.length).toBeGreaterThan(0) + }) + }) + + describe.skipIf(!supports.images)('multimodal input', () => { + it('accepts ContentBlock[] input', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + }) + + const yellowPng = await loadFixture(yellowPngUrl) + const imageBlock = new ImageBlock({ + format: 'png', + source: { bytes: yellowPng }, + }) + + const contentBlocks = [new TextBlock('What color is this image? Answer in one word.'), imageBlock] + + const result = await agent.invoke(contentBlocks) + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + }) + + it('accepts Message[] input for conversation history', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + }) + + const conversationHistory = [ + new Message({ + role: 'user', + content: [new TextBlock('Remember this number: 42')], + }), + new Message({ + role: 'assistant', + content: [new TextBlock('I will remember the number 42.')], + }), + new Message({ + role: 'user', + content: [new TextBlock('What number did I ask you to remember?')], + }), + ] + + const result = await agent.invoke(conversationHistory) + + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text).toMatch(/42/) + }) + }) + + it.skipIf(!supports.tools)('handles tool invocation', async () => { + const agent = new Agent({ + model: createModel(), + tools: [notebook, httpRequest], + printer: false, + }) + + await agent.invoke('Call Open-Meteo to get the weather in NYC, and take a note of what you did') + expect( + agent.messages.some((message) => + message.content.some((block) => block.type == 'toolUseBlock' && block.name == 'notebook') + ) + ).toBe(true) + expect( + agent.messages.some((message) => + message.content.some((block) => block.type == 'toolUseBlock' && block.name == 'http_request') + ) + ).toBe(true) + + // Validate multi-turn works after tool use + await collectGenerator(agent.stream('What was the result?')) + }) + + describe.skipIf(!supports.tools)('Structured Output', () => { + it('returns validated structured output', async () => { + const schema = z.object({ answer: z.number() }) + + const agent = new Agent({ + model: createModel(), + printer: false, + structuredOutputSchema: schema, + }) + + const result = await agent.invoke('What is 2 + 2?') + + expect(result.structuredOutput).toStrictEqual({ answer: 4 }) + }) + }) + + it.skipIf(!supports.reasoning)('emits reasoning content with thinking model', async () => { + const agent = new Agent({ + model: createModel(models.reasoning), + printer: false, + }) + + const { items, result } = await collectGenerator(agent.stream('What is 15 * 23? Think step by step.')) + + // Should have reasoning content deltas + const reasoningDeltas = items.filter( + (item) => + item.type === 'modelStreamUpdateEvent' && + item.event.type === 'modelContentBlockDeltaEvent' && + item.event.delta.type === 'reasoningContentDelta' + ) + expect(reasoningDeltas.length).toBeGreaterThan(0) + + // Should also have text content with the answer + expect(result.stopReason).toBe('endTurn') + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text).toContain('345') + }) + + it.skipIf(!supports.toolThinking)('handles tool use with thinking model', async () => { + const agent = new Agent({ + model: createModel(models.reasoning), + printer: false, + systemPrompt: 'Use the calculator tool to solve math problems. Respond with only the numeric result.', + tools: [calculatorTool], + }) + + const { items, result } = await collectGenerator(agent.stream('What is 789 * 321?')) + + // Should have reasoning content deltas + const reasoningDeltas = items.filter( + (item) => + item.type === 'modelStreamUpdateEvent' && + item.event.type === 'modelContentBlockDeltaEvent' && + item.event.delta.type === 'reasoningContentDelta' + ) + expect(reasoningDeltas.length).toBeGreaterThan(0) + + // Should have used the calculator tool + const toolUseMessage = agent.messages.find((msg) => + msg.content.some((block) => block.type === 'toolUseBlock' && block.name === 'calculator') + ) + expect(toolUseMessage).toBeDefined() + + // Verify reasoningSignature is present on tool use block + const toolUseBlock = toolUseMessage!.content.find( + (block): block is ToolUseBlock => block.type === 'toolUseBlock' && block.name === 'calculator' + ) + expect(toolUseBlock?.reasoningSignature).toBeDefined() + + // Should contain the correct result (789 * 321 = 253269) + expect(result.stopReason).toBe('endTurn') + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text).toMatch(/253269/) + + // Validate multi-turn works after tool use + await collectGenerator(agent.stream('What was the result?')) + }) + + it.skipIf(!supports.builtInTools)('handles built-in tools (code execution)', async () => { + const agent = new Agent({ + model: createModel('builtInTools' in models ? models.builtInTools : {}), + printer: false, + }) + + const result = await agent.invoke([ + new TextBlock('What is the sum of the first 50 prime numbers? Generate and run code to calculate it.'), + ]) + + expect(result.stopReason).toBe('endTurn') + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text).toMatch(/5117/) + + // Validate multi-turn works after built-in tool use + await collectGenerator(agent.stream('What was the result?')) + }) + }) +}) diff --git a/strands-ts/test/integ/conversation-manager/summarizing-conversation-manager.test.ts b/strands-ts/test/integ/conversation-manager/summarizing-conversation-manager.test.ts new file mode 100644 index 0000000000..7737f8c788 --- /dev/null +++ b/strands-ts/test/integ/conversation-manager/summarizing-conversation-manager.test.ts @@ -0,0 +1,195 @@ +import { describe, it, expect } from 'vitest' +import { + Agent, + ContextWindowOverflowError, + Message, + SummarizingConversationManager, + TextBlock, + ToolResultBlock, + ToolUseBlock, + tool, +} from '@strands-agents/sdk' +import { z } from 'zod' +import { bedrock } from '../__fixtures__/model-providers.js' + +function textMsg(role: 'user' | 'assistant', text: string): Message { + return new Message({ role, content: [new TextBlock(text)] }) +} + +const calculatorTool = tool({ + name: 'calculator', + description: 'Performs basic arithmetic operations', + inputSchema: z.object({ + operation: z.enum(['add', 'subtract', 'multiply', 'divide']), + a: z.number(), + b: z.number(), + }), + callback: async ({ operation, a, b }) => { + const ops = { add: a + b, subtract: a - b, multiply: a * b, divide: a / b } + return `Result: ${ops[operation]}` + }, +}) + +describe.skipIf(bedrock.skip)('SummarizingConversationManager Integration', () => { + it('summarizes older messages and agent remains functional after summarization', async () => { + const model = bedrock.createModel({ maxTokens: 1024 }) + const messages: Message[] = [ + textMsg('user', 'Hello, I am testing a conversation manager.'), + textMsg('assistant', 'Hello! I am here to help you test the conversation manager.'), + textMsg('user', 'Can you tell me about the history of computers?'), + textMsg( + 'assistant', + 'The history of computers spans many centuries, from the abacus to modern machines. Key milestones include the Pascaline (1642), ENIAC (1945), and the personal computer revolution of the 1980s.' + ), + textMsg('user', 'What were the first computers like?'), + textMsg( + 'assistant', + 'Early computers like ENIAC were enormous room-filling machines weighing about 30 tons, using thousands of vacuum tubes that generated tremendous heat and frequently failed.' + ), + ] + const lastTwo = messages.slice(-2) + + const manager = new SummarizingConversationManager({ + summaryRatio: 0.5, + preserveRecentMessages: 2, + }) + const agent = new Agent({ + model, + conversationManager: manager, + printer: false, + messages, + }) + + const result = await manager.reduce({ + agent, + model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(result).toBe(true) + // 6 messages, 50% ratio, preserve 2 → summarize 3, keep 3 → 1 summary + 3 = 4 + expect(agent.messages).toHaveLength(4) + + // First message should be the summary + const summary = agent.messages[0]! + expect(summary.role).toBe('user') + const summaryText = summary.content.find((b) => b.type === 'textBlock') as TextBlock + expect(summaryText).toBeDefined() + expect(summaryText.text.length).toBeGreaterThan(50) + + // Recent messages preserved + expect(agent.messages.slice(-2)).toEqual(lastTwo) + + // Agent should still be functional + const invokeResult = await agent.invoke('Thanks for the overview!') + expect(invokeResult.stopReason).toBe('endTurn') + expect(invokeResult.lastMessage.role).toBe('assistant') + }) + + it('keeps tool use/result pairs balanced after summarization', async () => { + const model = bedrock.createModel({ maxTokens: 1024 }) + // Messages indexed 0-13. With ratio 0.6 the initial split lands at index 8 + // (a toolResult for calc-1). The split-point adjuster walks forward past orphaned + // tool results to index 9 (plain text "25 + 37 = 62"), so indices 0-8 are summarized. + // The remaining messages (indices 9-13) include exactly one tool use/result pair + // (the weather tool at indices 11-12). + const messages: Message[] = [ + /* 0 */ textMsg('user', 'Hello, can you help me with some calculations?'), + /* 1 */ textMsg('assistant', 'Of course! I can help with calculations.'), + /* 2 */ textMsg('user', 'What is the current time?'), + /* 3 */ new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'get_time', toolUseId: 'time-1', input: {} })], + }), + /* 4 */ new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'time-1', + status: 'success', + content: [new TextBlock('2024-01-15 14:30:00')], + }), + ], + }), + /* 5 */ textMsg('assistant', 'The current time is 2024-01-15 14:30:00.'), + /* 6 */ textMsg('user', 'What is 25 + 37?'), + /* 7 */ new Message({ + role: 'assistant', + content: [ + new ToolUseBlock({ name: 'calculator', toolUseId: 'calc-1', input: { operation: 'add', a: 25, b: 37 } }), + ], + }), + /* 8 */ new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'calc-1', + status: 'success', + content: [new TextBlock('62')], + }), + ], + }), + /* 9 */ textMsg('assistant', '25 + 37 = 62'), + /* 10 */ textMsg('user', 'What is the weather in San Francisco?'), + /* 11 */ new Message({ + role: 'assistant', + content: [new ToolUseBlock({ name: 'get_weather', toolUseId: 'weather-1', input: { city: 'San Francisco' } })], + }), + /* 12 */ new Message({ + role: 'user', + content: [ + new ToolResultBlock({ + toolUseId: 'weather-1', + status: 'success', + content: [new TextBlock('Sunny and 72°F in San Francisco')], + }), + ], + }), + /* 13 */ textMsg('assistant', 'The weather in San Francisco is sunny and 72°F.'), + ] + + const manager = new SummarizingConversationManager({ + summaryRatio: 0.6, + preserveRecentMessages: 3, + }) + const agent = new Agent({ + model, + conversationManager: manager, + tools: [calculatorTool], + printer: false, + messages, + }) + + const result = await manager.reduce({ + agent, + model, + error: new ContextWindowOverflowError('overflow'), + }) + + expect(result).toBe(true) + // 9 summarized → 1 summary + 5 remaining = 6 + expect(agent.messages).toHaveLength(6) + + // Only the weather tool pair (indices 11-12) survives — time and calculator pairs were summarized + let toolUseCount = 0 + let toolResultCount = 0 + for (const msg of agent.messages) { + for (const block of msg.content) { + if (block.type === 'toolUseBlock') toolUseCount++ + if (block.type === 'toolResultBlock') toolResultCount++ + } + } + expect(toolUseCount).toBe(1) + expect(toolResultCount).toBe(1) + + // Agent should still work with tools + const invokeResult = await agent.invoke('Calculate 15 + 28 for me.') + expect(invokeResult.stopReason).toBe('endTurn') + + // Verify calculator tool was used + const hasCalcUse = agent.messages.some((msg) => + msg.content.some((block) => block.type === 'toolUseBlock' && block.name === 'calculator') + ) + expect(hasCalcUse).toBe(true) + }) +}) diff --git a/strands-ts/test/integ/environment.test.browser.ts b/strands-ts/test/integ/environment.test.browser.ts new file mode 100644 index 0000000000..4ad50752ce --- /dev/null +++ b/strands-ts/test/integ/environment.test.browser.ts @@ -0,0 +1,39 @@ +import { describe, it, expect } from 'vitest' + +import { isBrowser, isNode } from '$/sdk/__fixtures__/environment.js' + +describe('environment', () => { + describe('Browser compatibility', () => { + it('isNode should resolve to false', () => { + expect(isNode).toBe(false) + }) + it('has window object with expected properties', () => { + expect(window).toBeDefined() + expect(typeof window).toBe('object') + expect(window.location).toBeDefined() + expect(window.navigator).toBeDefined() + }) + + it('has document object with DOM methods', () => { + expect(document).toBeDefined() + expect(typeof document).toBe('object') + expect(typeof document.createElement).toBe('function') + expect(typeof document.querySelector).toBe('function') + }) + + it('has navigator object with browser information', () => { + expect(navigator).toBeDefined() + expect(typeof navigator).toBe('object') + expect(typeof navigator.userAgent).toBe('string') + expect(navigator.userAgent.length).toBeGreaterThan(0) + }) + + describe('environment detection', () => { + it('correctly identifies browser environment', () => { + expect(isBrowser).toBe(true) + expect(isNode).toBe(false) + expect(typeof window).toBe('object') + }) + }) + }) +}) diff --git a/strands-ts/test/integ/environment.test.node.ts b/strands-ts/test/integ/environment.test.node.ts new file mode 100644 index 0000000000..773f5372cf --- /dev/null +++ b/strands-ts/test/integ/environment.test.node.ts @@ -0,0 +1,21 @@ +import { describe, it, expect } from 'vitest' + +import { isBrowser, isNode } from '$/sdk/__fixtures__/environment.js' + +describe('environment', () => { + describe('Node.js compatibility', () => { + it('works in Node.js environment', () => { + // Test Node.js specific features are available + expect(typeof process).toBe('object') + expect(process.version).toBeDefined() + }) + }) + + describe('environment detection', () => { + it('correctly identifies Node.js environment', () => { + expect(isNode).toBe(true) + expect(isBrowser).toBe(false) + expect(typeof process).toBe('object') + }) + }) +}) diff --git a/strands-ts/test/integ/environment.test.ts b/strands-ts/test/integ/environment.test.ts new file mode 100644 index 0000000000..8ba7fe4c7a --- /dev/null +++ b/strands-ts/test/integ/environment.test.ts @@ -0,0 +1,28 @@ +import { describe, it, expect } from 'vitest' + +describe('environment', () => { + describe('JavaScript features', () => { + it('supports modern JavaScript features', () => { + // Test ES2022 features work + const testArray = [1, 2, 3] + const lastElement = testArray.at(-1) + expect(lastElement).toBe(3) + }) + + it('supports async/await functionality', async () => { + // Test async functionality works + const promise = Promise.resolve('test') + const result = await promise + expect(result).toBe('test') + }) + }) + + describe('TypeScript configuration', () => { + it('validates strict typing environment', () => { + // This test validates strict TypeScript configuration + // If this compiles and runs, strict typing is working + const testValue: string = 'test' + expect(typeof testValue).toBe('string') + }) + }) +}) diff --git a/strands-ts/test/integ/interrupt.test.ts b/strands-ts/test/integ/interrupt.test.ts new file mode 100644 index 0000000000..91697c5332 --- /dev/null +++ b/strands-ts/test/integ/interrupt.test.ts @@ -0,0 +1,110 @@ +import { describe, expect, it } from 'vitest' +import { Agent, BeforeToolCallEvent, tool } from '@strands-agents/sdk' +import { z } from 'zod' +import { bedrock } from './__fixtures__/model-providers.js' +import { resumeUntilDone, timeTool, weatherTool } from './__fixtures__/test-helpers.js' + +// Tool that interrupts to ask for the time +const interruptTimeTool = tool({ + name: 'time_tool', + description: 'Returns the current time', + inputSchema: z.object({}), + callback: async (_input, context) => { + return context!.interrupt({ name: 'test_interrupt', reason: 'need time' }) as string + }, +}) + +describe.skipIf(bedrock.skip)('Interrupts', () => { + describe('hook interrupts', () => { + function createAgentWithApprovalHook() { + const agent = new Agent({ + model: bedrock.createModel(), + printer: false, + tools: [timeTool, weatherTool], + }) + agent.addHook(BeforeToolCallEvent, (event) => { + if (event.toolUse.name === 'weather_tool') return + const response = event.interrupt({ name: 'test_interrupt', reason: 'need approval' }) + if (response !== 'APPROVE') { + event.cancel = 'tool rejected' + } + }) + return agent + } + + it('interrupts before tool call, resumes with approval', async () => { + const agent = createAgentWithApprovalHook() + + const result = await agent.invoke('What is the time and weather?') + + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts).toBeDefined() + expect(result.interrupts!.length).toBeGreaterThanOrEqual(1) + + const interrupt = result.interrupts![0]! + expect(interrupt.name).toBe('test_interrupt') + expect(interrupt.reason).toBe('need approval') + + const finalResult = await resumeUntilDone(agent, result, () => 'APPROVE') + + expect(finalResult.stopReason).toBe('endTurn') + + const text = finalResult.lastMessage.content + .filter((b) => b.type === 'textBlock') + .map((b) => b.text) + .join(' ') + .toLowerCase() + expect(text).toMatch(/12:00|sunny/) + }) + + it('interrupts before tool call, resumes with rejection cancels tool', async () => { + const agent = createAgentWithApprovalHook() + + const result = await agent.invoke('What is the time and weather?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => 'REJECT') + expect(finalResult.stopReason).toBe('endTurn') + + // Verify at least one tool result was an error (the rejected tool) + const hasErrorResult = agent.messages.some( + (msg) => msg.role === 'user' && msg.content.some((b) => b.type === 'toolResultBlock' && b.status === 'error') + ) + expect(hasErrorResult).toBe(true) + }) + }) + + describe('tool interrupts', () => { + it('interrupts from tool callback, resumes with response', async () => { + const agent = new Agent({ + model: bedrock.createModel(), + printer: false, + tools: [interruptTimeTool, weatherTool], + }) + + const result = await agent.invoke('What is the time and weather?') + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts).toBeDefined() + expect(result.interrupts!.length).toBeGreaterThanOrEqual(1) + + for (const interrupt of result.interrupts!) { + expect(interrupt.response).toBeUndefined() + } + + const finalResult = await resumeUntilDone(agent, result, (interrupt) => + interrupt.reason === 'need time' ? '12:01' : 'yes' + ) + + expect(finalResult.stopReason).toBe('endTurn') + + const lastAssistant = agent.messages.filter((m) => m.role === 'assistant').pop() + expect(lastAssistant).toBeDefined() + const finalText = lastAssistant!.content + .filter((b) => b.type === 'textBlock') + .map((b) => b.text) + .join(' ') + .toLowerCase() + expect(finalText).toMatch(/12:01|sunny/) + }) + }) +}) diff --git a/strands-ts/test/integ/interventions.test.ts b/strands-ts/test/integ/interventions.test.ts new file mode 100644 index 0000000000..0241acf44f --- /dev/null +++ b/strands-ts/test/integ/interventions.test.ts @@ -0,0 +1,1411 @@ +import { describe, expect, it } from 'vitest' +import { + Agent, + InterventionHandler, + InterventionActions, + AfterToolCallEvent, + BeforeInvocationEvent, + BeforeToolCallEvent, + InterruptResponseContent, + tool, +} from '@strands-agents/sdk' +import type { JSONValue } from '@strands-agents/sdk' +import { z } from 'zod' +import { collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { allProviders } from './__fixtures__/model-providers.js' +import { + countToolResults, + getToolResultText, + resumeUntilDone, + timeTool, + weatherTool, + echoTool, +} from './__fixtures__/test-helpers.js' + +// ========== Intervention Handler Implementations ========== + +class DenyAllToolsHandler extends InterventionHandler { + readonly name = 'deny-all-tools' + + override beforeToolCall() { + return InterventionActions.deny('All tool use is blocked by policy') + } +} + +class DenySpecificToolHandler extends InterventionHandler { + readonly name = 'deny-specific-tool' + private readonly blockedTool: string + + constructor(blockedTool: string) { + super() + this.blockedTool = blockedTool + } + + override beforeToolCall(event: BeforeToolCallEvent) { + if (event.toolUse.name === this.blockedTool) { + return InterventionActions.deny(`Tool "${this.blockedTool}" is not allowed`) + } + return InterventionActions.proceed() + } +} + +class ConfirmToolHandler extends InterventionHandler { + readonly name = 'confirm-tool' + + override beforeToolCall(event: BeforeToolCallEvent) { + return InterventionActions.confirm(`Approve use of ${event.toolUse.name}?`) + } +} + +class PreemptiveConfirmHandler extends InterventionHandler { + readonly name = 'preemptive-confirm' + private readonly answer: JSONValue + + constructor(answer: JSONValue) { + super() + this.answer = answer + } + + override beforeToolCall(event: BeforeToolCallEvent) { + return InterventionActions.confirm(`Approve ${event.toolUse.name}?`, { response: this.answer }) + } +} + +class GuideBeforeToolHandler extends InterventionHandler { + readonly name = 'guide-before-tool' + private readonly feedback: string + + constructor(feedback: string) { + super() + this.feedback = feedback + } + + override beforeToolCall() { + return InterventionActions.guide(this.feedback) + } +} + +class GuideAfterModelHandler extends InterventionHandler { + readonly name = 'guide-after-model' + private readonly feedback: string + private callCount = 0 + private readonly maxGuides: number + + constructor(feedback: string, maxGuides = 1) { + super() + this.feedback = feedback + this.maxGuides = maxGuides + } + + override afterModelCall() { + if (this.callCount < this.maxGuides) { + this.callCount++ + return InterventionActions.guide(this.feedback) + } + return InterventionActions.proceed() + } +} + +class TransformToolInputHandler extends InterventionHandler { + readonly name = 'transform-tool-input' + private readonly transformFn: (input: Record) => Record + + constructor(transformFn: (input: Record) => Record) { + super() + this.transformFn = transformFn + } + + override beforeToolCall(event: BeforeToolCallEvent) { + const transformed = this.transformFn(event.toolUse.input as Record) + return InterventionActions.transform((e) => { + ;(e as BeforeToolCallEvent).toolUse.input = transformed as JSONValue + }) + } +} + +class TransformToolResultHandler extends InterventionHandler { + readonly name = 'transform-tool-result' + + override afterToolCall(_event: AfterToolCallEvent) { + return InterventionActions.transform((e) => { + const afterEvent = e as AfterToolCallEvent + if (afterEvent.result.status === 'success') { + const content = afterEvent.result.content + for (const block of content) { + if (block.type === 'textBlock') { + Object.assign(block, { text: block.text.replace(/\d+/g, '[REDACTED]') }) + } + } + } + }) + } +} + +class DenyInvocationHandler extends InterventionHandler { + readonly name = 'deny-invocation' + + override beforeInvocation(_event: BeforeInvocationEvent) { + return InterventionActions.deny('Invocation blocked by policy') + } +} + +class DenyBeforeModelHandler extends InterventionHandler { + readonly name = 'deny-before-model' + + override beforeModelCall() { + return InterventionActions.deny('Model call blocked by intervention') + } +} + +class ErrorThrowingHandler extends InterventionHandler { + readonly name = 'error-throw' + override readonly onError = 'throw' as const + + override beforeToolCall(): never { + throw new Error('Handler exploded') + } +} + +class ErrorProceedHandler extends InterventionHandler { + readonly name = 'error-proceed' + override readonly onError = 'proceed' as const + + override beforeToolCall(): never { + throw new Error('Handler exploded but should continue') + } +} + +class ErrorDenyHandler extends InterventionHandler { + readonly name = 'error-deny' + override readonly onError = 'deny' as const + + override beforeToolCall(): never { + throw new Error('Handler exploded and should deny') + } +} + +class CustomEvaluateConfirmHandler extends InterventionHandler { + readonly name = 'custom-evaluate-confirm' + + override beforeToolCall(event: BeforeToolCallEvent) { + return InterventionActions.confirm(`Approve ${event.toolUse.name}?`, { + evaluate: (response) => response === 'MAGIC_WORD', + }) + } +} + +// ========== Tests ========== + +describe.each(allProviders)('Interventions with $name', ({ name, skip, createModel, supports }) => { + describe.skipIf(skip || !supports.tools)(`${name} Intervention Integration Tests`, () => { + describe('deny action', () => { + it('deny on beforeToolCall blocks tool and agent completes gracefully', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new DenyAllToolsHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('deny on beforeToolCall only blocks the specified tool', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: + 'When asked about time and weather, use BOTH time_tool AND weather_tool. Always use both tools.', + tools: [timeTool, weatherTool], + interventions: [new DenySpecificToolHandler('time_tool')], + }) + + const result = await agent.invoke('What is the time and weather?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + + it('deny on beforeInvocation cancels the invocation', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + tools: [timeTool], + interventions: [new DenyInvocationHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content.some((b) => b.type === 'textBlock' && b.text.includes('DENIED'))).toBe(true) + }) + + it('deny on beforeModelCall prevents model from being called', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + tools: [timeTool], + interventions: [new DenyBeforeModelHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.content.some((b) => b.type === 'textBlock' && b.text.includes('DENIED'))).toBe(true) + }) + }) + + describe('confirm action', () => { + it('confirm pauses agent execution and resumes with approval', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts).toBeDefined() + expect(result.interrupts!.length).toBeGreaterThanOrEqual(1) + expect(result.interrupts![0]!.name).toBe('confirm-tool') + + const finalResult = await resumeUntilDone(agent, result, () => 'yes') + expect(finalResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + + it('confirm with denial blocks tool execution', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => 'no') + expect(finalResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('confirm with preemptive approval does not pause agent', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new PreemptiveConfirmHandler('yes')], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + + it('confirm with preemptive denial blocks tool without pausing', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new PreemptiveConfirmHandler('no')], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('confirm with custom evaluate uses custom approval logic', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new CustomEvaluateConfirmHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + // 'yes' would pass default evaluate but fails custom (requires 'MAGIC_WORD') + const deniedResult = await resumeUntilDone(agent, result, () => 'yes') + expect(deniedResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('confirm with custom evaluate accepts custom approval value', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new CustomEvaluateConfirmHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const approvedResult = await resumeUntilDone(agent, result, () => 'MAGIC_WORD') + expect(approvedResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + }) + + describe('guide action', () => { + it('guide on beforeToolCall cancels tool with feedback for model', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new GuideBeforeToolHandler('Please use a different approach')], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('guide on afterModelCall triggers retry with feedback', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Answer questions directly without using tools.', + tools: [], + interventions: [new GuideAfterModelHandler('Please be more specific in your answer')], + }) + + const result = await agent.invoke('Hello') + expect(result.stopReason).toBe('endTurn') + + const guidanceMessages = agent.messages.filter( + (m) => + m.role === 'user' && m.content.some((b) => b.type === 'textBlock' && b.text.includes('be more specific')) + ) + expect(guidanceMessages.length).toBeGreaterThanOrEqual(1) + }) + + it('guide on beforeModelCall injects feedback as user message', async () => { + let guideCalled = false + + class OneTimeGuide extends InterventionHandler { + readonly name = 'onetime-guide' + override beforeModelCall() { + if (!guideCalled) { + guideCalled = true + return InterventionActions.guide('Remember to be concise') + } + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + tools: [], + interventions: [new OneTimeGuide()], + }) + + const result = await agent.invoke('Tell me a joke') + expect(result.stopReason).toBe('endTurn') + + const guidanceMessages = agent.messages.filter( + (m) => m.role === 'user' && m.content.some((b) => b.type === 'textBlock' && b.text.includes('be concise')) + ) + expect(guidanceMessages.length).toBeGreaterThanOrEqual(1) + }) + }) + + describe('transform action', () => { + it('transform on beforeToolCall modifies tool input', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: + 'Use the echo_tool to echo messages. When asked to echo something, call echo_tool with that message.', + tools: [echoTool], + interventions: [ + new TransformToolInputHandler((input) => ({ + ...input, + message: `[TRANSFORMED] ${input.message ?? 'hello'}`, + })), + ], + }) + + const result = await agent.invoke('Echo the message "hello world"') + expect(result.stopReason).toBe('endTurn') + + expect(getToolResultText(agent.messages, 'success')).toContain('[TRANSFORMED]') + }) + + it('transform on afterToolCall modifies tool output', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new TransformToolResultHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + + expect(getToolResultText(agent.messages, 'success')).toContain('[REDACTED]') + }) + }) + + describe('multiple handlers', () => { + it('handlers execute in registration order and first deny short-circuits', async () => { + let secondHandlerCalled = false + + class SecondHandler extends InterventionHandler { + readonly name = 'second-handler' + override beforeToolCall() { + secondHandlerCalled = true + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new DenyAllToolsHandler(), new SecondHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(secondHandlerCalled).toBe(false) + }) + + it('proceed from first handler allows second handler to evaluate', async () => { + let secondHandlerCalled = false + + class ProceedFirstHandler extends InterventionHandler { + readonly name = 'proceed-first' + override beforeToolCall() { + return InterventionActions.proceed() + } + } + + class TrackingDenyHandler extends InterventionHandler { + readonly name = 'tracking-deny' + override beforeToolCall() { + secondHandlerCalled = true + return InterventionActions.deny('Blocked by second handler') + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ProceedFirstHandler(), new TrackingDenyHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(secondHandlerCalled).toBe(true) + }) + + it('transform then deny: transform applies before deny blocks', async () => { + let transformApplied = false + + class TrackingTransformHandler extends InterventionHandler { + readonly name = 'tracking-transform' + override beforeToolCall() { + return InterventionActions.transform(() => { + transformApplied = true + }) + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new TrackingTransformHandler(), new DenyAllToolsHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(transformApplied).toBe(true) + }) + }) + + describe('error handling', () => { + it('onError=throw propagates handler errors', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ErrorThrowingHandler()], + }) + + await expect(agent.invoke('What time is it?')).rejects.toThrow('Handler exploded') + }) + + it('onError=proceed swallows error and allows tool to run', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ErrorProceedHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + + it('onError=deny fails closed and blocks the tool', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ErrorDenyHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + }) + + describe('async handlers', () => { + it('awaits async handler returning deny', async () => { + class AsyncDenyHandler extends InterventionHandler { + readonly name = 'async-deny' + override async beforeToolCall() { + await new Promise((resolve) => setTimeout(resolve, 10)) + return InterventionActions.deny('Async denial') + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new AsyncDenyHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('awaits async handler returning confirm', async () => { + class AsyncConfirmHandler extends InterventionHandler { + readonly name = 'async-confirm' + override async beforeToolCall(event: BeforeToolCallEvent) { + await new Promise((resolve) => setTimeout(resolve, 10)) + return InterventionActions.confirm(`Approve ${event.toolUse.name}?`) + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new AsyncConfirmHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => 'yes') + expect(finalResult.stopReason).toBe('endTurn') + }) + }) + + describe('multi-lifecycle handlers', () => { + it('handler can implement multiple lifecycle methods', async () => { + let beforeToolCalled = false + let afterToolCalled = false + + class MultiLifecycleHandler extends InterventionHandler { + readonly name = 'multi-lifecycle' + + override beforeToolCall() { + beforeToolCalled = true + return InterventionActions.proceed() + } + + override afterToolCall() { + afterToolCalled = true + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new MultiLifecycleHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(beforeToolCalled).toBe(true) + expect(afterToolCalled).toBe(true) + }) + + it('handler evaluates each tool call independently', async () => { + let toolCallCount = 0 + + class CountingHandler extends InterventionHandler { + readonly name = 'counting-handler' + override beforeToolCall() { + toolCallCount++ + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: + 'When asked about time and weather, you MUST call BOTH time_tool AND weather_tool. Always use both.', + tools: [timeTool, weatherTool], + interventions: [new CountingHandler()], + }) + + const result = await agent.invoke('What is the time and weather?') + expect(result.stopReason).toBe('endTurn') + expect(toolCallCount).toBeGreaterThanOrEqual(2) + }) + }) + + describe('confirm with interrupt/resume flow', () => { + it('confirm on multiple tool calls collects interrupts for each', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: + 'When asked about time and weather, you MUST call BOTH time_tool AND weather_tool. Always use both.', + tools: [timeTool, weatherTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What is the time and weather?') + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts!.length).toBeGreaterThanOrEqual(1) + + const finalResult = await resumeUntilDone(agent, result, () => 'yes') + expect(finalResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + + it('confirm interrupt includes handler name and prompt as reason', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const interrupt = result.interrupts![0]! + expect(interrupt.name).toBe('confirm-tool') + expect(interrupt.reason).toContain('Approve') + }) + + it('resume with InterruptResponseContent instances works', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const responses = result.interrupts!.map( + (interrupt) => + new InterruptResponseContent({ + interruptId: interrupt.id, + response: 'yes', + }) + ) + + const resumed = await agent.invoke(responses) + const finalResult = await resumeUntilDone(agent, resumed, () => 'yes') + expect(finalResult.stopReason).toBe('endTurn') + }) + }) + + describe('guide accumulation across multiple handlers', () => { + it('accumulates feedback from multiple guide handlers into one cancellation', async () => { + class SecurityGuide extends InterventionHandler { + readonly name = 'security-guide' + override beforeToolCall() { + return InterventionActions.guide('Ensure input is sanitized') + } + } + + class ComplianceGuide extends InterventionHandler { + readonly name = 'compliance-guide' + override beforeToolCall() { + return InterventionActions.guide('Check compliance requirements') + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new SecurityGuide(), new ComplianceGuide()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + + const errorText = getToolResultText(agent.messages, 'error') + expect(errorText).toContain('sanitized') + expect(errorText).toContain('compliance') + }) + }) + + describe('transform chaining', () => { + it('multiple transforms apply in sequence and later handlers see mutations', async () => { + class PrefixTransform extends InterventionHandler { + readonly name = 'prefix-transform' + override beforeToolCall(event: BeforeToolCallEvent) { + const input = event.toolUse.input as Record + return InterventionActions.transform((e) => { + ;(e as BeforeToolCallEvent).toolUse.input = { + ...input, + message: `[PREFIX] ${input.message ?? ''}`, + } as JSONValue + }) + } + } + + class SuffixTransform extends InterventionHandler { + readonly name = 'suffix-transform' + override beforeToolCall(_event: BeforeToolCallEvent) { + return InterventionActions.transform((e) => { + const current = (e as BeforeToolCallEvent).toolUse.input as Record + ;(e as BeforeToolCallEvent).toolUse.input = { + ...current, + message: `${current.message || ''} [SUFFIX]`, + } as JSONValue + }) + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use echo_tool to echo messages. Call echo_tool with the message provided.', + tools: [echoTool], + interventions: [new PrefixTransform(), new SuffixTransform()], + }) + + const result = await agent.invoke('Echo "test"') + expect(result.stopReason).toBe('endTurn') + + const resultText = getToolResultText(agent.messages, 'success') + expect(resultText).toContain('[PREFIX]') + expect(resultText).toContain('[SUFFIX]') + }) + + it('transform on afterToolCall can redact sensitive data before model sees it', async () => { + const sensitiveDataTool = tool({ + name: 'user_data_tool', + description: 'Returns user data. Always call this tool when asked about user info.', + inputSchema: z.object({}), + callback: async () => 'SSN: 123-45-6789, Name: John Doe', + }) + + class RedactSSNHandler extends InterventionHandler { + readonly name = 'redact-ssn' + override afterToolCall(_event: AfterToolCallEvent) { + return InterventionActions.transform((e) => { + const afterEvent = e as AfterToolCallEvent + for (const block of afterEvent.result.content) { + if (block.type === 'textBlock') { + Object.assign(block, { text: block.text.replace(/\d{3}-\d{2}-\d{4}/g, '***-**-****') }) + } + } + }) + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use user_data_tool to get user information.', + tools: [sensitiveDataTool], + interventions: [new RedactSSNHandler()], + }) + + const result = await agent.invoke('What is the user data?') + expect(result.stopReason).toBe('endTurn') + + const resultText = getToolResultText(agent.messages, 'success') + expect(resultText).toContain('***-**-****') + expect(resultText).not.toContain('123-45-6789') + }) + }) + + describe('conditional interventions based on tool input', () => { + class InputValidationHandler extends InterventionHandler { + readonly name = 'input-validation' + override beforeToolCall(event: BeforeToolCallEvent) { + const input = event.toolUse.input as Record + if (typeof input.message === 'string' && input.message.includes('DROP TABLE')) { + return InterventionActions.deny('SQL injection detected in tool input') + } + return InterventionActions.proceed() + } + } + + it('denies tool call based on input content', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use echo_tool to echo exactly what the user says. Pass their exact message.', + tools: [echoTool], + interventions: [new InputValidationHandler()], + }) + + const result = await agent.invoke('Echo this: DROP TABLE users') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('allows tool call when input passes validation', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use echo_tool to echo what the user says.', + tools: [echoTool], + interventions: [new InputValidationHandler()], + }) + + const result = await agent.invoke('Echo: hello world') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + }) + + describe('stateful handlers with appState', () => { + it('handler reads appState to make policy decisions', async () => { + class RateLimitHandler extends InterventionHandler { + readonly name = 'rate-limit' + override beforeToolCall(event: BeforeToolCallEvent) { + const callCount = (event.agent.appState.get('toolCallCount') as number) ?? 0 + event.agent.appState.set('toolCallCount', callCount + 1) + if (callCount >= 2) { + return InterventionActions.deny('Rate limit exceeded: max 2 tool calls per session') + } + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: + 'Use the time_tool to answer time questions. If the tool fails, just say you cannot get the time.', + tools: [timeTool], + appState: { toolCallCount: 0 }, + interventions: [new RateLimitHandler()], + }) + + const result1 = await agent.invoke('What time is it?') + expect(result1.stopReason).toBe('endTurn') + expect(agent.appState.get('toolCallCount')).toBeGreaterThanOrEqual(1) + + await agent.invoke('What time is it again?') + const finalCount = agent.appState.get('toolCallCount') as number + expect(finalCount).toBeGreaterThanOrEqual(2) + }) + + it('handler uses appState for per-tool allow list', async () => { + class AllowListHandler extends InterventionHandler { + readonly name = 'allow-list' + override beforeToolCall(event: BeforeToolCallEvent) { + const allowedTools = (event.agent.appState.get('allowedTools') as string[]) ?? [] + if (!allowedTools.includes(event.toolUse.name)) { + return InterventionActions.deny(`Tool "${event.toolUse.name}" is not in the allow list`) + } + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool, weatherTool], + appState: { allowedTools: ['weather_tool'] }, + interventions: [new AllowListHandler()], + }) + + const result = await agent.invoke('What is the time and weather?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + }) + + describe('confirm with varied response types', () => { + it('confirm accepts boolean true as approval', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => true) + expect(finalResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + + it('confirm rejects boolean false', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => false) + expect(finalResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('confirm rejects null response', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => null) + expect(finalResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('confirm accepts case-insensitive YES', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => 'YES') + expect(finalResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + + it('confirm accepts whitespace-padded " yes "', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => ' yes ') + expect(finalResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(1) + }) + + it('confirm rejects empty string', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => '') + expect(finalResult.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + }) + + describe('intervention interaction with agent lifecycle', () => { + it('deny on beforeInvocation prevents any model or tool interaction', async () => { + let modelCalled = false + + class TrackingDenyInvocation extends InterventionHandler { + readonly name = 'tracking-deny-invocation' + override beforeInvocation() { + return InterventionActions.deny('Blocked at invocation level') + } + override beforeModelCall() { + modelCalled = true + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + tools: [timeTool], + interventions: [new TrackingDenyInvocation()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(modelCalled).toBe(false) + expect(result.lastMessage.content.some((b) => b.type === 'textBlock' && b.text.includes('DENIED'))).toBe(true) + }) + + it('intervention handler can inspect tool name to apply per-tool policies', async () => { + const toolDecisions: Record = {} + + class PerToolPolicyHandler extends InterventionHandler { + readonly name = 'per-tool-policy' + override beforeToolCall(event: BeforeToolCallEvent) { + if (event.toolUse.name === 'time_tool') { + toolDecisions[event.toolUse.name] = 'denied' + return InterventionActions.deny('time_tool requires elevated permissions') + } + toolDecisions[event.toolUse.name] = 'allowed' + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: + 'When asked about time and weather, you MUST call BOTH time_tool AND weather_tool. Always use both.', + tools: [timeTool, weatherTool], + interventions: [new PerToolPolicyHandler()], + }) + + const result = await agent.invoke('What is the time and weather?') + expect(result.stopReason).toBe('endTurn') + expect(toolDecisions['time_tool']).toBe('denied') + expect(toolDecisions['weather_tool']).toBe('allowed') + }) + + it('afterModelCall guide causes model to retry with guidance injected', async () => { + let attemptCount = 0 + + class RetryOnceGuide extends InterventionHandler { + readonly name = 'retry-once-guide' + override afterModelCall() { + attemptCount++ + if (attemptCount === 1) { + return InterventionActions.guide('Please include the word VERIFIED in your response') + } + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + tools: [], + interventions: [new RetryOnceGuide()], + }) + + const result = await agent.invoke('Say hello') + expect(result.stopReason).toBe('endTurn') + expect(attemptCount).toBeGreaterThanOrEqual(2) + + const guidanceMessages = agent.messages.filter( + (m) => m.role === 'user' && m.content.some((b) => b.type === 'textBlock' && b.text.includes('VERIFIED')) + ) + expect(guidanceMessages.length).toBeGreaterThanOrEqual(1) + }) + + it('intervention runs on every tool call in a multi-tool response', async () => { + const toolsSeen: string[] = [] + + class TrackAllToolsHandler extends InterventionHandler { + readonly name = 'track-all-tools' + override beforeToolCall(event: BeforeToolCallEvent) { + toolsSeen.push(event.toolUse.name) + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: + 'When asked about time and weather, you MUST call BOTH time_tool AND weather_tool in the same response.', + tools: [timeTool, weatherTool], + interventions: [new TrackAllToolsHandler()], + }) + + const result = await agent.invoke('What is the time and weather?') + expect(result.stopReason).toBe('endTurn') + expect(toolsSeen.length).toBeGreaterThanOrEqual(2) + expect(toolsSeen).toContain('time_tool') + expect(toolsSeen).toContain('weather_tool') + }) + }) + + describe('mixed handler strategies', () => { + it('confirm handler followed by transform: approved tool gets transformed input', async () => { + class ApproveAndTransform extends InterventionHandler { + readonly name = 'approve-and-transform' + override beforeToolCall(event: BeforeToolCallEvent) { + return InterventionActions.confirm(`Approve ${event.toolUse.name}?`) + } + } + + class AddMetadata extends InterventionHandler { + readonly name = 'add-metadata' + override beforeToolCall(event: BeforeToolCallEvent) { + const input = event.toolUse.input as Record + return InterventionActions.transform((e) => { + ;(e as BeforeToolCallEvent).toolUse.input = { + ...input, + message: `[AUDITED] ${input.message ?? 'data'}`, + } as JSONValue + }) + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use echo_tool to echo messages.', + tools: [echoTool], + interventions: [new ApproveAndTransform(), new AddMetadata()], + }) + + const result = await agent.invoke('Echo "test data"') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => 'yes') + expect(finalResult.stopReason).toBe('endTurn') + + expect(getToolResultText(agent.messages, 'success')).toContain('[AUDITED]') + }) + + it('denied confirm short-circuits and skips subsequent transform', async () => { + let transformApplied = false + + class ConfirmFirst extends InterventionHandler { + readonly name = 'confirm-first' + override beforeToolCall(event: BeforeToolCallEvent) { + return InterventionActions.confirm(`Approve ${event.toolUse.name}?`) + } + } + + class TransformSecond extends InterventionHandler { + readonly name = 'transform-second' + override beforeToolCall() { + return InterventionActions.transform(() => { + transformApplied = true + }) + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmFirst(), new TransformSecond()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('interrupt') + + const finalResult = await resumeUntilDone(agent, result, () => 'no') + expect(finalResult.stopReason).toBe('endTurn') + expect(transformApplied).toBe(false) + }) + + it('proceed + proceed + deny: first two handlers pass, third blocks', async () => { + const handlerLog: string[] = [] + + class FirstProceed extends InterventionHandler { + readonly name = 'first-proceed' + override beforeToolCall() { + handlerLog.push('first') + return InterventionActions.proceed() + } + } + + class SecondProceed extends InterventionHandler { + readonly name = 'second-proceed' + override beforeToolCall() { + handlerLog.push('second') + return InterventionActions.proceed() + } + } + + class ThirdDeny extends InterventionHandler { + readonly name = 'third-deny' + override beforeToolCall() { + handlerLog.push('third') + return InterventionActions.deny('Blocked by third handler') + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new FirstProceed(), new SecondProceed(), new ThirdDeny()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(handlerLog).toEqual(['first', 'second', 'third']) + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + }) + + describe('error recovery scenarios', () => { + it('onError=proceed followed by working handler: agent uses second handler result', async () => { + class FailingHandler extends InterventionHandler { + readonly name = 'failing-handler' + override readonly onError = 'proceed' as const + override beforeToolCall(): never { + throw new Error('External service timeout') + } + } + + class WorkingDenyHandler extends InterventionHandler { + readonly name = 'working-deny' + override beforeToolCall() { + return InterventionActions.deny('Blocked by working handler') + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new FailingHandler(), new WorkingDenyHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('onError=deny short-circuits before later handlers run', async () => { + let laterHandlerCalled = false + + class FailDenyHandler extends InterventionHandler { + readonly name = 'fail-deny' + override readonly onError = 'deny' as const + override beforeToolCall(): never { + throw new Error('Auth service down') + } + } + + class LaterHandler extends InterventionHandler { + readonly name = 'later-handler' + override beforeToolCall() { + laterHandlerCalled = true + return InterventionActions.proceed() + } + } + + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new FailDenyHandler(), new LaterHandler()], + }) + + const result = await agent.invoke('What time is it?') + expect(result.stopReason).toBe('endTurn') + expect(laterHandlerCalled).toBe(false) + }) + }) + + describe('streaming compatibility', () => { + it('interventions work correctly when using agent.stream()', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new DenyAllToolsHandler()], + }) + + const { items, result } = await collectGenerator(agent.stream('What time is it?')) + + expect(items.length).toBeGreaterThan(0) + expect(result.stopReason).toBe('endTurn') + expect(countToolResults(agent.messages, 'error')).toBeGreaterThanOrEqual(1) + }) + + it('confirm interrupt works via stream API', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + systemPrompt: 'Use the time_tool to answer time questions.', + tools: [timeTool], + interventions: [new ConfirmToolHandler()], + }) + + const { result } = await collectGenerator(agent.stream('What time is it?')) + expect(result.stopReason).toBe('interrupt') + expect(result.interrupts).toBeDefined() + + const finalResult = await resumeUntilDone(agent, result, () => 'yes') + expect(finalResult.stopReason).toBe('endTurn') + }) + }) + }) +}) diff --git a/strands-ts/test/integ/mcp/mcp-tasks.test.node.ts b/strands-ts/test/integ/mcp/mcp-tasks.test.node.ts new file mode 100644 index 0000000000..c0182bcbf0 --- /dev/null +++ b/strands-ts/test/integ/mcp/mcp-tasks.test.node.ts @@ -0,0 +1,225 @@ +import { describe, it, expect, beforeAll, afterAll } from 'vitest' +import { McpClient, Agent } from '@strands-agents/sdk' +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' +import { startTaskHTTPServer, type TaskHttpServerInfo } from '../__fixtures__/test-mcp-task-server.js' +import { startHTTPServer, type HttpServerInfo } from '../__fixtures__/test-mcp-server.js' +import { bedrock } from '../__fixtures__/model-providers.js' +import { hasToolUse, countToolResults } from '../__fixtures__/test-helpers.js' + +import type { TasksConfig } from '@strands-agents/sdk' + +/** + * Creates a connected McpClient for the given server URL. + * Returns the client - caller is responsible for disconnecting. + * @param serverUrl - The URL of the MCP server + * @param appName - The application name for the client + * @param tasksConfig - Optional tasks configuration. When provided, enables task-based tool invocation. + */ +function createClient(serverUrl: string, appName: string, tasksConfig?: TasksConfig): McpClient { + return new McpClient({ + applicationName: appName, + transport: new StreamableHTTPClientTransport(new URL(serverUrl)), + ...(tasksConfig !== undefined && { tasksConfig }), + }) +} + +describe('MCP Task Integration Tests', () => { + let taskServerInfo: TaskHttpServerInfo | undefined + let nonTaskServerInfo: HttpServerInfo | undefined + + beforeAll(async () => { + // Start both servers in parallel + ;[taskServerInfo, nonTaskServerInfo] = await Promise.all([startTaskHTTPServer(), startHTTPServer()]) + }, 30000) + + afterAll(async () => { + // Clean up both servers + await Promise.all([taskServerInfo?.close(), nonTaskServerInfo?.close()]) + }, 30000) + + describe('McpClient.callTool() with Task-Enabled Server', () => { + it('extracts result from task tool that completes immediately', async () => { + if (!taskServerInfo) throw new Error('Task server not started') + + const client = createClient(taskServerInfo.url, 'test-task-client', {}) + try { + await client.connect() + const tools = await client.listTools() + const instantTool = tools.find((t) => t.name === 'instant_task') + expect(instantTool).toBeDefined() + + // McpClient.callTool uses callToolStream internally + const result = await client.callTool(instantTool!, { value: 'hello from instant task' }) + + expect(result).toMatchObject({ + content: expect.arrayContaining([expect.objectContaining({ type: 'text', text: 'hello from instant task' })]), + }) + } finally { + await client.disconnect() + } + }, 30000) + + it('extracts result from long-running task with progress updates', async () => { + if (!taskServerInfo) throw new Error('Task server not started') + + const client = createClient(taskServerInfo.url, 'test-task-client', {}) + try { + await client.connect() + const tools = await client.listTools() + const longRunningTool = tools.find((t) => t.name === 'long_running_task') + expect(longRunningTool).toBeDefined() + + // McpClient.callTool should wait for the task to complete and return the final result + const result = await client.callTool(longRunningTool!, { + duration: 300, + message: 'Long task completed successfully!', + }) + + expect(result).toMatchObject({ + content: expect.arrayContaining([ + expect.objectContaining({ type: 'text', text: 'Long task completed successfully!' }), + ]), + }) + } finally { + await client.disconnect() + } + }, 30000) + + it('throws error for failed tasks (MCP SDK behavior)', async () => { + if (!taskServerInfo) throw new Error('Task server not started') + + const client = createClient(taskServerInfo.url, 'test-task-client', {}) + try { + await client.connect() + const tools = await client.listTools() + const failingTool = tools.find((t) => t.name === 'failing_task') + expect(failingTool).toBeDefined() + + // McpClient.callTool uses takeResult() which throws on task failure + await expect(client.callTool(failingTool!, { error_message: 'This task failed on purpose!' })).rejects.toThrow( + /failed/i + ) + } finally { + await client.disconnect() + } + }, 30000) + }) + + describe('McpClient.callTool() with Non-Task Server (Backward Compatibility)', () => { + it('extracts result from regular (non-task) tools', async () => { + if (!nonTaskServerInfo) throw new Error('Non-task server not started') + + const client = createClient(nonTaskServerInfo.url, 'test-compat-client') + try { + await client.connect() + const tools = await client.listTools() + const echoTool = tools.find((t) => t.name === 'echo') + expect(echoTool).toBeDefined() + + const result = await client.callTool(echoTool!, { message: 'backward compat test' }) + + expect(result).toMatchObject({ + content: expect.arrayContaining([expect.objectContaining({ type: 'text', text: 'backward compat test' })]), + }) + } finally { + await client.disconnect() + } + }, 30000) + + it('handles calculator tool with complex arguments', async () => { + if (!nonTaskServerInfo) throw new Error('Non-task server not started') + + const client = createClient(nonTaskServerInfo.url, 'test-compat-client') + try { + await client.connect() + const tools = await client.listTools() + const calculatorTool = tools.find((t) => t.name === 'calculator') + expect(calculatorTool).toBeDefined() + + const result = await client.callTool(calculatorTool!, { operation: 'multiply', a: 6, b: 7 }) + + expect(result).toMatchObject({ + content: expect.arrayContaining([expect.objectContaining({ type: 'text', text: 'Result: 42' })]), + }) + } finally { + await client.disconnect() + } + }, 30000) + }) + + describe('Agent Integration with Task Tools', () => { + it('agent can use task tools in a conversation', async () => { + if (!taskServerInfo) throw new Error('Task server not started') + + const client = createClient(taskServerInfo.url, 'test-agent-task-client', {}) + try { + const model = bedrock.createModel({ maxTokens: 300 }) + const agent = new Agent({ + systemPrompt: + 'You are a helpful assistant. When asked to run a task, use the instant_task tool with the value provided by the user.', + tools: [client], + model, + }) + + const result = await agent.invoke('Please run an instant task with the value "agent test message"') + + expect(result).toBeDefined() + expect(result.stopReason).toBeDefined() + expect(hasToolUse(agent.messages, 'instant_task')).toBe(true) + expect(countToolResults(agent.messages, 'success')).toBeGreaterThan(0) + } finally { + await client.disconnect() + } + }, 60000) + + it('agent handles task tool errors gracefully', async () => { + if (!taskServerInfo) throw new Error('Task server not started') + + const client = createClient(taskServerInfo.url, 'test-agent-task-client', {}) + try { + const model = bedrock.createModel({ maxTokens: 300 }) + const agent = new Agent({ + systemPrompt: 'You are a helpful assistant. When asked to test error handling, use the failing_task tool.', + tools: [client], + model, + }) + + const result = await agent.invoke('Please use the failing_task tool to test error handling.') + + expect(result).toBeDefined() + expect(hasToolUse(agent.messages, 'failing_task')).toBe(true) + expect(countToolResults(agent.messages, 'error')).toBeGreaterThan(0) + } finally { + await client.disconnect() + } + }, 60000) + + it('agent can use multiple task tools in a multi-turn conversation', async () => { + if (!taskServerInfo) throw new Error('Task server not started') + + const client = createClient(taskServerInfo.url, 'test-agent-multi-task-client', {}) + try { + const model = bedrock.createModel({ maxTokens: 300 }) + const agent = new Agent({ + systemPrompt: + 'You are a helpful assistant. Use task tools when requested. Available tools: instant_task (quick), long_running_task (takes time).', + tools: [client], + model, + }) + + // First turn: use instant_task + await agent.invoke('Run an instant task with value "first turn"') + expect(hasToolUse(agent.messages, 'instant_task')).toBe(true) + + // Second turn: use long_running_task + await agent.invoke('Now run a long running task with message "second turn complete"') + expect(hasToolUse(agent.messages, 'long_running_task')).toBe(true) + + // Both tool results should be successful + expect(countToolResults(agent.messages, 'success')).toBeGreaterThanOrEqual(2) + } finally { + await client.disconnect() + } + }, 90000) + }) +}) diff --git a/strands-ts/test/integ/mcp/mcp.test.node.ts b/strands-ts/test/integ/mcp/mcp.test.node.ts new file mode 100644 index 0000000000..31951d5256 --- /dev/null +++ b/strands-ts/test/integ/mcp/mcp.test.node.ts @@ -0,0 +1,165 @@ +/** + * MCP Integration Tests + * + * Tests Agent integration with MCP servers using all supported transport types. + * Verifies that agents can successfully use MCP tools via the Bedrock model. + */ + +import { describe, it, expect, beforeAll, afterAll, vi } from 'vitest' +import { McpClient, Agent } from '@strands-agents/sdk' +import type { ElicitationCallback } from '@strands-agents/sdk' +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' +import { resolve } from 'node:path' +import { URL } from 'node:url' +import { startHTTPServer, type HttpServerInfo } from '../__fixtures__/test-mcp-server.js' +import { bedrock } from '../__fixtures__/model-providers.js' + +type TransportConfig = { + name: string + createClient: () => McpClient | Promise + cleanup?: () => Promise +} + +describe('MCP Integration Tests', () => { + const serverPath = resolve(process.cwd(), 'test/integ/__fixtures__/test-mcp-server.ts') + let httpServerInfo: HttpServerInfo | undefined + + beforeAll(async () => { + // Start HTTP server + httpServerInfo = await startHTTPServer() + }, 30000) + + afterAll(async () => { + if (httpServerInfo) { + await httpServerInfo.close() + } + }, 30000) + + const transports: TransportConfig[] = [ + { + name: 'stdio', + createClient: () => { + return new McpClient({ + applicationName: 'test-mcp-stdio', + transport: new StdioClientTransport({ + command: 'npx', + args: ['tsx', serverPath], + }), + }) + }, + }, + { + name: 'Streamable HTTP', + createClient: () => { + if (!httpServerInfo) throw new Error('HTTP server not started') + return new McpClient({ + applicationName: 'test-mcp-http', + transport: new StreamableHTTPClientTransport(new URL(httpServerInfo.url)), + }) + }, + }, + ] + + describe.each(transports)('$name transport', ({ createClient }) => { + it('agent can use multiple MCP tools in a conversation', async () => { + const client = await createClient() + const model = bedrock.createModel({ maxTokens: 300 }) + + const agent = new Agent({ + systemPrompt: + 'You are a helpful assistant. Use the echo tool to repeat messages and the calculator tool for arithmetic.', + tools: [client], + model, + }) + + // First turn: Use echo tool + await agent.invoke('Use the echo tool to say "Multi-turn test"') + + // Verify echo tool was used + const hasEchoUse = agent.messages.some((msg) => + msg.content.some((block) => block.type === 'toolUseBlock' && block.name === 'echo') + ) + expect(hasEchoUse).toBe(true) + + // Second turn: Use calculator tool in same conversation + const result = await agent.invoke('Now use the calculator tool to add 15 and 27') + + expect(result).toBeDefined() + expect(result.stopReason).toBeDefined() + + // Verify calculator tool was used + const hasCalculatorUse = agent.messages.some((msg) => + msg.content.some((block) => block.type === 'toolUseBlock' && block.name === 'calculator') + ) + expect(hasCalculatorUse).toBe(true) + }, 60000) + + it('agent handles MCP tool errors gracefully', async () => { + const client = await createClient() + const model = bedrock.createModel({ maxTokens: 200 }) + + const agent = new Agent({ + systemPrompt: 'You are a helpful assistant. If asked to test errors, use the error_tool.', + tools: [client], + model, + }) + + const result = await agent.invoke('Use the error_tool to test error handling.') + + expect(result).toBeDefined() + + // Verify the error was encountered + const hasErrorResult = agent.messages.some((msg) => + msg.content.some((block) => block.type === 'toolResultBlock' && block.status === 'error') + ) + expect(hasErrorResult).toBe(true) + }, 30000) + }) + + // Elicitation handler registration is transport-agnostic (happens in McpClient.connect), + // so a single transport suffices here. + describe('elicitation', () => { + it('agent can use MCP tool that requests elicitation', async () => { + const elicitationCallback: ElicitationCallback = vi.fn().mockResolvedValue({ + action: 'accept', + content: { confirmed: true }, + }) + + const client = new McpClient({ + applicationName: 'test-mcp-elicitation', + transport: new StdioClientTransport({ + command: 'npx', + args: ['tsx', serverPath], + }), + elicitationCallback, + }) + + const model = bedrock.createModel({ maxTokens: 300 }) + + const agent = new Agent({ + systemPrompt: 'You are a helpful assistant. Use the confirm_action tool when asked to confirm something.', + tools: [client], + model, + }) + + const result = await agent.invoke('Use the confirm_action tool to confirm "deploy to production"') + + expect(result).toBeDefined() + expect(result.stopReason).toBeDefined() + expect(elicitationCallback).toHaveBeenCalled() + + const hasConfirmUse = agent.messages.some((msg) => + msg.content.some((block) => block.type === 'toolUseBlock' && block.name === 'confirm_action') + ) + expect(hasConfirmUse).toBe(true) + + const hasSuccessResult = agent.messages.some((msg) => + msg.content.some((block) => block.type === 'toolResultBlock' && block.status === 'success') + ) + expect(hasSuccessResult).toBe(true) + + await client.disconnect() + }, 60000) + }) +}) diff --git a/strands-ts/test/integ/models/anthropic.test.ts b/strands-ts/test/integ/models/anthropic.test.ts new file mode 100644 index 0000000000..70bf76cd4b --- /dev/null +++ b/strands-ts/test/integ/models/anthropic.test.ts @@ -0,0 +1,189 @@ +import { describe, expect, it } from 'vitest' +import { Message, ImageBlock, TextBlock, CachePointBlock } from '@strands-agents/sdk' +import type { SystemContentBlock } from '@strands-agents/sdk' +import { collectIterator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { loadFixture } from '../__fixtures__/test-helpers.js' +import { anthropic } from '../__fixtures__/model-providers.js' + +import yellowPngUrl from '../__resources__/yellow.png?url' + +describe.skipIf(anthropic.skip)('AnthropicModel Integration Tests', () => { + describe('Configuration', () => { + it.concurrent('respects maxTokens configuration', async () => { + const provider = anthropic.createModel({ maxTokens: 20 }) + const messages = [ + new Message({ + role: 'user', + content: [new TextBlock('Write a very long story about space exploration.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const metadataEvent = events.find((e) => e.type === 'modelMetadataEvent') + expect(metadataEvent?.usage?.outputTokens).toBeLessThanOrEqual(20) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent?.stopReason).toBe('maxTokens') + }) + }) + + describe('Prompt Caching', () => { + it('uses system prompt cache on subsequent requests', async () => { + const provider = anthropic.createModel({ maxTokens: 100 }) + + const largeContext = `Context information: ${'repeat '.repeat(5000)} [${Date.now()}]` + + const cachedSystemPrompt: SystemContentBlock[] = [ + new TextBlock('You are a helpful assistant.'), + new TextBlock(largeContext), + new CachePointBlock({ cacheType: 'default' }), + ] + + const events1 = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hello')] })], { + systemPrompt: cachedSystemPrompt, + }) + ) + + const metadata1 = events1.find((e) => e.type === 'modelMetadataEvent') + const writeTokens = metadata1?.usage?.cacheWriteInputTokens + if (writeTokens !== undefined) { + expect(writeTokens).toBeGreaterThan(0) + } + + const events2 = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Hi again')] })], { + systemPrompt: cachedSystemPrompt, + }) + ) + + const metadata2 = events2.find((e) => e.type === 'modelMetadataEvent') + const readTokens = metadata2?.usage?.cacheReadInputTokens + if (readTokens !== undefined) { + expect(readTokens).toBeGreaterThanOrEqual(0) + } + }) + + it('uses message cache points on subsequent requests', async () => { + const provider = anthropic.createModel({ maxTokens: 100 }) + const largeContext = `Context information: ${'repeat '.repeat(5000)} [${Date.now()}]` + + const messagesWithCache = (text: string): Message[] => [ + new Message({ + role: 'user', + content: [new TextBlock(largeContext), new CachePointBlock({ cacheType: 'default' }), new TextBlock(text)], + }), + ] + + const events1 = await collectIterator(provider.stream(messagesWithCache('Question 1'))) + const metadata1 = events1.find((e) => e.type === 'modelMetadataEvent') + const writeTokens = metadata1?.usage?.cacheWriteInputTokens + if (writeTokens !== undefined) { + expect(writeTokens).toBeGreaterThan(0) + } + + const events2 = await collectIterator(provider.stream(messagesWithCache('Question 2'))) + const metadata2 = events2.find((e) => e.type === 'modelMetadataEvent') + const readTokens = metadata2?.usage?.cacheReadInputTokens + if (readTokens !== undefined) { + expect(readTokens).toBeGreaterThanOrEqual(0) + } + }) + }) + + describe('Media Support', () => { + it('processes image input correctly', async () => { + const provider = anthropic.createModel({ maxTokens: 100 }) + + const imageBytes = await loadFixture(yellowPngUrl) + + const messages = [ + new Message({ + role: 'user', + content: [ + new ImageBlock({ + format: 'png', + source: { bytes: imageBytes }, + }), + new TextBlock('What color is this image? Reply with just the color name.'), + ], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const stopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(stopEvent?.stopReason).toBe('endTurn') + + let fullText = '' + for (const event of events) { + if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + fullText += event.delta.text + } + } + + expect(fullText.toLowerCase()).toContain('yellow') + }) + }) + + describe('Thinking Mode', () => { + it('emits thinking blocks when enabled', async () => { + const provider = anthropic.createModel({ + maxTokens: 4000, + params: { + thinking: { + type: 'enabled', + budget_tokens: 2048, + }, + }, + }) + + const messages = [ + new Message({ + role: 'user', + content: [new TextBlock('Explain the theory of relativity step-by-step.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const thinkingEvents = events.filter( + (e) => e.type === 'modelContentBlockDeltaEvent' && e.delta.type === 'reasoningContentDelta' + ) + + if (thinkingEvents.length > 0) { + expect(thinkingEvents[0]!.type).toBe('modelContentBlockDeltaEvent') + const firstThinking = thinkingEvents[0] as any + expect(firstThinking.delta.text).toBeDefined() + } + }) + }) + + describe('countTokens', () => { + const messages = [ + new Message({ role: 'user', content: [new TextBlock('What is the capital of France? Explain in detail.')] }), + ] + const toolSpecs = [ + { + name: 'get_weather', + description: 'Get the current weather for a location', + inputSchema: { type: 'object' as const, properties: { location: { type: 'string' as const } } }, + }, + ] + + it.concurrent('should count tokens for messages only', async () => { + const model = anthropic.createModel() + const result = await model.countTokens(messages) + expect(typeof result).toBe('number') + expect(result).toBeGreaterThan(0) + }) + + it.concurrent('should return more tokens with tools and system prompt', async () => { + const model = anthropic.createModel() + const without = await model.countTokens(messages) + const withTools = await model.countTokens(messages, { toolSpecs, systemPrompt: 'Be helpful.' }) + expect(withTools).toBeGreaterThan(without) + }) + }) +}) diff --git a/strands-ts/test/integ/models/bedrock.test.node.ts b/strands-ts/test/integ/models/bedrock.test.node.ts new file mode 100644 index 0000000000..d406ed1078 --- /dev/null +++ b/strands-ts/test/integ/models/bedrock.test.node.ts @@ -0,0 +1,59 @@ +import { describe, expect, it, vi } from 'vitest' +import { bedrock } from '../__fixtures__/model-providers.js' +import { Agent } from '$/sdk/agent/agent.js' + +describe.skipIf(bedrock.skip)('BedrockModel Integration Tests', () => { + describe('Agent with String Model ID', () => { + it.concurrent('accepts string model ID and creates functional Agent', async () => { + // Create agent with string model ID + const agent = new Agent({ + model: 'global.anthropic.claude-sonnet-4-6', + printer: false, + }) + + // Invoke agent with simple prompt + const result = await agent.invoke('Say hello') + + // Verify agent works correctly + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + expect(result.lastMessage.content.length).toBeGreaterThan(0) + + // Verify message contains text content + const textContent = result.lastMessage.content.find((block) => block.type === 'textBlock') + expect(textContent).toBeDefined() + expect(textContent?.text).toBeTruthy() + }) + }) + + describe('Region Configuration', () => { + it('uses AWS_REGION environment variable when set', async () => { + // Use vitest to stub the environment variable + vi.stubEnv('AWS_REGION', 'eu-central-1') + + const provider = bedrock.createModel({ + maxTokens: 50, + }) + + // Validate AWS_REGION environment variable is used + // Making an actual request doesn't guarantee the correct region is being used + const regionResult = await provider['_client'].config.region() + expect(regionResult).toBe('eu-central-1') + }) + + it('explicit region takes precedence over environment variable', async () => { + // Use vitest to stub the environment variable + vi.stubEnv('AWS_REGION', 'eu-west-1') + + const provider = bedrock.createModel({ + region: 'ap-southeast-2', + maxTokens: 50, + }) + + // Validate explicit region takes precedence over environment variable + // Making an actual request doesn't guarantee the correct region is being used + const regionResult = await provider['_client'].config.region() + expect(regionResult).toBe('ap-southeast-2') + }) + }) +}) diff --git a/strands-ts/test/integ/models/bedrock.test.ts b/strands-ts/test/integ/models/bedrock.test.ts new file mode 100644 index 0000000000..944d90c194 --- /dev/null +++ b/strands-ts/test/integ/models/bedrock.test.ts @@ -0,0 +1,771 @@ +import { beforeAll, describe, expect, it, vi } from 'vitest' +import { + Agent, + Message, + NullConversationManager, + SlidingWindowConversationManager, + TextBlock, + FunctionTool, + CachePointBlock, + ImageBlock, +} from '@strands-agents/sdk' +import type { SystemContentBlock, ModelRedactionEvent } from '@strands-agents/sdk' + +import { collectIterator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { bedrock } from '../__fixtures__/model-providers.js' +import { loadFixture } from '../__fixtures__/test-helpers.js' +import yellowPngUrl from '../__resources__/yellow.png?url' +import { + BedrockClient, + CreateGuardrailCommand, + GetGuardrailCommand, + ListGuardrailsCommand, +} from '@aws-sdk/client-bedrock' +import { inject } from 'vitest' + +describe.skipIf(bedrock.skip)('BedrockModel Integration Tests', () => { + describe('Streaming', () => { + describe('Configuration', () => { + it.concurrent('respects maxTokens configuration', async () => { + const provider = bedrock.createModel({ maxTokens: 20 }) + const messages = [ + new Message({ + role: 'user', + content: [new TextBlock('Write a long story about dragons.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const metadataEvent = events.find((e) => e.type === 'modelMetadataEvent') + expect(metadataEvent?.usage?.outputTokens).toBeLessThanOrEqual(20) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent?.stopReason).toBe('maxTokens') + }) + + it.concurrent('uses system prompt cache on subsequent requests', async () => { + const provider = bedrock.createModel({ + modelId: 'global.anthropic.claude-sonnet-4-5-20250929-v1:0', + maxTokens: 100, + }) + const largeContext = `Context information: ${'hello '.repeat(2000)} [test-${Date.now()}-${Math.random()}]` + const cachedSystemPrompt: SystemContentBlock[] = [ + new TextBlock('You are a helpful assistant.'), + new TextBlock(largeContext), + new CachePointBlock({ cacheType: 'default' }), + ] + + // First request - creates cache + const events1 = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Say hello')] })], { + systemPrompt: cachedSystemPrompt, + }) + ) + const metadata1 = events1.find((e) => e.type === 'modelMetadataEvent') + expect(metadata1?.usage?.cacheWriteInputTokens).toBeGreaterThan(0) + + // Second request - should use cache + const events2 = await collectIterator( + provider.stream([new Message({ role: 'user', content: [new TextBlock('Say goodbye')] })], { + systemPrompt: cachedSystemPrompt, + }) + ) + const metadata2 = events2.find((e) => e.type === 'modelMetadataEvent') + expect(metadata2?.usage?.cacheReadInputTokens).toBeGreaterThan(0) + }) + + it.concurrent('uses message cache points on subsequent requests', async () => { + const provider = bedrock.createModel({ + modelId: 'global.anthropic.claude-sonnet-4-5-20250929-v1:0', + maxTokens: 100, + }) + const largeContext = `Context information: ${'hello '.repeat(2000)} [test-${Date.now()}-${Math.random()}]` + const messagesWithCachePoint = (text: string): Message[] => [ + new Message({ + role: 'user', + content: [new TextBlock(largeContext), new CachePointBlock({ cacheType: 'default' }), new TextBlock(text)], + }), + ] + + // First request - creates cache + const events1 = await collectIterator(provider.stream(messagesWithCachePoint('Say hello'))) + const metadata1 = events1.find((e) => e.type === 'modelMetadataEvent') + expect(metadata1?.usage?.cacheWriteInputTokens).toBeGreaterThan(0) + + // Second request - should use cache + const events2 = await collectIterator(provider.stream(messagesWithCachePoint('Say goodbye'))) + const metadata2 = events2.find((e) => e.type === 'modelMetadataEvent') + expect(metadata2?.usage?.cacheReadInputTokens).toBeGreaterThan(0) + }) + + it.concurrent('uses cacheConfig to automatically inject cache points in tools and messages', async () => { + const provider = bedrock.createModel({ + modelId: 'global.anthropic.claude-sonnet-4-5-20250929-v1:0', + maxTokens: 100, + cacheConfig: { strategy: 'auto' }, + }) + const largeContext = `Context information: ${'hello '.repeat(2000)} [test-${Date.now()}-${Math.random()}]` + + const toolSpecs = [ + { + name: 'lookup', + description: 'Look up information. '.repeat(100), + inputSchema: { type: 'object' as const, properties: { query: { type: 'string' as const } } }, + }, + ] + + const messages = [new Message({ role: 'user', content: [new TextBlock(largeContext)] })] + + // First request - writes to cache + const events1 = await collectIterator(provider.stream(messages, { toolSpecs })) + const metadata1 = events1.find((e) => e.type === 'modelMetadataEvent') + expect(metadata1?.usage?.cacheWriteInputTokens).toBeGreaterThan(0) + + // Second request - identical content, should read from cache + const events2 = await collectIterator(provider.stream(messages, { toolSpecs })) + const metadata2 = events2.find((e) => e.type === 'modelMetadataEvent') + expect(metadata2?.usage?.cacheReadInputTokens).toBeGreaterThan(0) + }) + + it.concurrent( + 'uses cacheConfig with explicit anthropic strategy for application inference profiles', + async () => { + const provider = bedrock.createModel({ + modelId: 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', + maxTokens: 100, + cacheConfig: { strategy: 'anthropic' }, + }) + const largeContext = `Context information: ${'hello '.repeat(2000)} [test-${Date.now()}-${Math.random()}]` + + const messages = [new Message({ role: 'user', content: [new TextBlock(largeContext)] })] + + // First request - writes to cache + const events1 = await collectIterator(provider.stream(messages)) + const metadata1 = events1.find((e) => e.type === 'modelMetadataEvent') + expect(metadata1?.usage?.cacheWriteInputTokens).toBeGreaterThan(0) + + // Second request - identical content, should read from cache + const events2 = await collectIterator(provider.stream(messages)) + const metadata2 = events2.find((e) => e.type === 'modelMetadataEvent') + expect(metadata2?.usage?.cacheReadInputTokens).toBeGreaterThan(0) + } + ) + }) + + describe('Error Handling', () => { + it.concurrent('handles invalid model ID gracefully', async () => { + const provider = bedrock.createModel({ modelId: 'invalid-model-id-that-does-not-exist' }) + const messages = [new Message({ role: 'user', content: [new TextBlock('Hello')] })] + await expect(collectIterator(provider.stream(messages))).rejects.toThrow() + }) + }) + }) + + describe('Agent with Conversation Manager', () => { + it('manages conversation history with SlidingWindowConversationManager', async () => { + const agent = new Agent({ + model: bedrock.createModel({ maxTokens: 100 }), + conversationManager: new SlidingWindowConversationManager({ windowSize: 4 }), + }) + + // First exchange + await agent.invoke('Count from 1 to 1.') + expect(agent.messages).toHaveLength(2) // user + assistant + + // Second exchange + await agent.invoke('Count from 2 to 2.') + expect(agent.messages).toHaveLength(4) // 2 user + 2 assistant + + // Third exchange - should trigger sliding window + await agent.invoke('Count from 3 to 3.') + + // Should maintain window size of 4 messages + expect(agent.messages).toHaveLength(4) + }, 30000) + + it('throws ContextWindowOverflowError with NullConversationManager', async () => { + const agent = new Agent({ + model: bedrock.createModel({ maxTokens: 50 }), + conversationManager: new NullConversationManager(), + }) + + // Generate a message that would require context management + const longPrompt = 'Please write a very detailed explanation of ' + 'many topics '.repeat(50) + + // This should throw since NullConversationManager doesn't handle overflow + await expect(agent.invoke(longPrompt)).rejects.toThrow() + }, 30000) + }) + + describe('Region Configuration', () => { + it('uses explicit region when provided', async () => { + const provider = bedrock.createModel({ + region: 'us-east-1', + maxTokens: 50, + }) + + // Validate region configuration by checking config.region() directly + // Making an actual request doesn't guarantee the correct region is being used + const regionResult = await provider['_client'].config.region() + expect(regionResult).toBe('us-east-1') + }) + + it('uses region from clientConfig when provided', async () => { + const provider = bedrock.createModel({ + clientConfig: { region: 'ap-northeast-1' }, + maxTokens: 50, + }) + + // Validate clientConfig region is used + // Making an actual request doesn't guarantee the correct region is being used + const regionResult = await provider['_client'].config.region() + expect(regionResult).toBe('ap-northeast-1') + }) + + it('defaults to us-west-2 when no region provided and AWS SDK does not resolve one', async () => { + // Use vitest to stub environment variables + vi.stubEnv('AWS_REGION', undefined) + vi.stubEnv('AWS_DEFAULT_REGION', undefined) + // Point config and credential files to null values + vi.stubEnv('AWS_CONFIG_FILE', '/dev/null') + vi.stubEnv('AWS_SHARED_CREDENTIALS_FILE', '/dev/null') + + const provider = bedrock.createModel({ + maxTokens: 50, + }) + + // Validate region defaults to us-west-2 + // Making an actual request doesn't guarantee the correct region is being used + const regionResult = await provider['_client'].config.region() + expect(regionResult).toBe('us-west-2') + + // ensure that invocation works + await collectIterator( + provider.stream([ + Message.fromMessageData({ + role: 'user', + content: [new TextBlock('say hi')], + }), + ]) + ) + }) + + it('uses region from clientConfig when provided', async () => { + const provider = bedrock.createModel({ + clientConfig: { region: 'ap-northeast-1' }, + maxTokens: 50, + }) + + // Validate clientConfig region is used + // Making an actual request doesn't guarantee the correct region is being used + const regionResult = await provider['_client'].config.region() + expect(regionResult).toBe('ap-northeast-1') + }) + }) + + describe('Thinking Mode with Tools', () => { + it('handles thinking mode with tool use', async () => { + const bedrockModel = bedrock.createModel({ + modelId: 'global.anthropic.claude-sonnet-4-6', + additionalRequestFields: { + thinking: { + type: 'enabled', + budget_tokens: 1024, + }, + }, + maxTokens: 2048, + }) + + const testTool = new FunctionTool({ + name: 'testTool', + description: 'Test description', + inputSchema: { type: 'object' }, + callback: (): string => 'result', + }) + + // Create agent with thinking mode and tool + const agent = new Agent({ + model: bedrockModel, + tools: [testTool], + printer: false, + }) + + // Invoke agent with a prompt that triggers tool use + const result = await agent.invoke('Use the testTool with the message "Hello World"') + + // Verify the agent completed successfully + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + expect(result.lastMessage.content.length).toBeGreaterThan(0) + + // Verify the tool was used + const toolUseMessage = agent.messages.find((msg) => msg.content.some((block) => block.type === 'toolUseBlock')) + expect(toolUseMessage).toBeDefined() + + // Verify the tool result is in the history + const toolResultMessage = agent.messages.find((msg) => + msg.content.some((block) => block.type === 'toolResultBlock') + ) + expect(toolResultMessage).toBeDefined() + }, 30000) + }) + + describe('Guardrails', () => { + const BLOCKED_INPUT = 'BLOCKED_INPUT' + const BLOCKED_OUTPUT = 'BLOCKED_OUTPUT' + const GUARDRAIL_NAME = 'test-guardrail-block-cactus' + + let GUARDRAIL_ID: string | undefined + + /** + * Gets the guardrail ID by name if it exists + */ + async function getGuardrailId(client: BedrockClient, guardrailName: string): Promise { + const response = await client.send(new ListGuardrailsCommand({})) + const guardrail = response.guardrails?.find((g) => g.name === guardrailName) + return guardrail?.id + } + + /** + * Waits for the guardrail to become active + */ + async function waitForGuardrailActive( + client: BedrockClient, + guardrailId: string, + maxAttempts = 10, + delayMs = 5000 + ): Promise { + for (let i = 0; i < maxAttempts; i++) { + const response = await client.send(new GetGuardrailCommand({ guardrailIdentifier: guardrailId })) + const status = response.status + + if (status === 'READY') { + console.log(`Guardrail ${guardrailId} is now active`) + return + } + + console.log(`Waiting for guardrail to become active. Current status: ${status}`) + await new Promise((resolve) => setTimeout(resolve, delayMs)) + } + + throw new Error(`Guardrail did not become active within ${(maxAttempts * delayMs) / 1000} seconds`) + } + + /** + * Creates or retrieves the test guardrail + */ + async function setupGuardrail(): Promise { + const credentials = inject('provider-bedrock')?.credentials + if (!credentials) { + throw new Error('No Bedrock credentials provided') + } + + const client = new BedrockClient({ region: 'us-east-1', credentials }) + + // Check if guardrail already exists + let guardrailId = await getGuardrailId(client, GUARDRAIL_NAME) + + if (guardrailId) { + console.log(`Guardrail ${GUARDRAIL_NAME} already exists with ID: ${guardrailId}`) + } else { + console.log(`Creating guardrail ${GUARDRAIL_NAME}`) + const response = await client.send( + new CreateGuardrailCommand({ + name: GUARDRAIL_NAME, + description: 'Testing Guardrail', + wordPolicyConfig: { + wordsConfig: [ + { + text: 'CACTUS', + }, + ], + }, + blockedInputMessaging: BLOCKED_INPUT, + blockedOutputsMessaging: BLOCKED_OUTPUT, + }) + ) + guardrailId = response.guardrailId + if (!guardrailId) { + throw new Error('Failed to create guardrail: no ID returned') + } + console.log(`Created test guardrail with ID: ${guardrailId}`) + await waitForGuardrailActive(client, guardrailId) + } + + if (!guardrailId) { + throw new Error('Failed to get or create guardrail') + } + + return guardrailId + } + + beforeAll(async () => { + GUARDRAIL_ID = await setupGuardrail() + }, 60000) + + describe('Input Intervention', () => { + it.each(['enabled', 'enabled_full'] as const)( + 'blocks input and redacts message with trace=%s', + async (guardrailTrace) => { + const model = bedrock.createModel({ + region: 'us-east-1', + guardrailConfig: { + guardrailIdentifier: GUARDRAIL_ID!, + guardrailVersion: 'DRAFT', + trace: guardrailTrace, + redaction: { + input: true, + inputMessage: 'Redacted.', + }, + }, + }) + + const agent = new Agent({ + model, + systemPrompt: 'You are a helpful assistant.', + printer: false, + }) + + const response1 = await agent.invoke('CACTUS') + const response2 = await agent.invoke('Hello!') + + expect(response1.stopReason).toBe('guardrailIntervened') + expect(response1.toString().trim()).toBe(BLOCKED_INPUT) + expect(response2.stopReason).not.toBe('guardrailIntervened') + expect(response2.toString().trim()).not.toBe(BLOCKED_INPUT) + expect(agent.messages[0]?.content[0]?.type).toBe('textBlock') + const firstBlock = agent.messages[0]?.content[0] + if (firstBlock?.type === 'textBlock') { + expect(firstBlock.text).toBe('Redacted.') + } + }, + 30000 + ) + }) + + describe('Output Intervention', () => { + it.each(['sync', 'async'] as const)( + 'blocks output without redaction in %s mode', + async (processingMode) => { + const model = bedrock.createModel({ + modelId: 'global.anthropic.claude-sonnet-4-5-20250929-v1:0', + region: 'us-east-1', + guardrailConfig: { + guardrailIdentifier: GUARDRAIL_ID!, + guardrailVersion: 'DRAFT', + streamProcessingMode: processingMode, + redaction: { + output: false, + }, + }, + }) + + const agent = new Agent({ + model, + systemPrompt: 'When asked to say the word, say CACTUS.', + printer: false, + }) + + const response1 = await agent.invoke('Say the word.') + const response2 = await agent.invoke('Hello!') + + expect(response1.stopReason).toBe('guardrailIntervened') + + if (processingMode === 'sync') { + // In sync mode, we can reliably check the response content + expect(response1.toString()).toContain(BLOCKED_OUTPUT) + expect(response2.stopReason).not.toBe('guardrailIntervened') + expect(response2.toString()).not.toContain(BLOCKED_OUTPUT) + } else { + // In async mode, either: + // - CACTUS was returned and blocked by input guardrail on next turn, or + // - CACTUS was blocked in response1, allowing normal response2 + const cactusCaughtByInputGuardrail = response2.toString().includes(BLOCKED_INPUT) + const cactusBlockedAllowsNextResponse = + !response2.toString().includes(BLOCKED_OUTPUT) && response2.stopReason !== 'guardrailIntervened' + expect(cactusCaughtByInputGuardrail || cactusBlockedAllowsNextResponse).toBe(true) + } + }, + 30000 + ) + + it.each([ + ['sync', 'enabled'], + ['sync', 'enabled_full'], + ['async', 'enabled'], + ['async', 'enabled_full'], + ] as const)( + 'blocks output with redaction in %s mode with trace=%s', + async (processingMode, guardrailTrace) => { + const REDACT_MESSAGE = 'Redacted.' + const model = bedrock.createModel({ + region: 'us-east-1', + guardrailConfig: { + guardrailIdentifier: GUARDRAIL_ID!, + guardrailVersion: 'DRAFT', + streamProcessingMode: processingMode, + trace: guardrailTrace, + redaction: { + output: true, + outputMessage: REDACT_MESSAGE, + }, + }, + temperature: 0, // Deterministic responses + }) + + const agent = new Agent({ + model, + systemPrompt: 'When asked to say the word, say CACTUS. Otherwise, respond normally.', + printer: false, + }) + + const response1 = await agent.invoke('Say the word.') + // Use unrelated prompt to avoid model volunteering CACTUS + const response2 = await agent.invoke('What is 2+2? Reply with only the number.') + + expect(response1.stopReason).toBe('guardrailIntervened') + + if (processingMode === 'sync') { + expect(response1.toString()).toContain(REDACT_MESSAGE) + expect(response2.stopReason).not.toBe('guardrailIntervened') + expect(response2.toString()).not.toContain(REDACT_MESSAGE) + } else { + // In async mode, either: + // - CACTUS was returned and blocked by input guardrail on next turn, or + // - CACTUS was blocked in response1, allowing normal response2 + const cactusCaughtByInputGuardrail = response2.toString().includes(BLOCKED_INPUT) + const cactusBlockedAllowsNextResponse = + !response2.toString().includes(REDACT_MESSAGE) && response2.stopReason !== 'guardrailIntervened' + expect(cactusCaughtByInputGuardrail || cactusBlockedAllowsNextResponse).toBe(true) + } + }, + 30000 + ) + + it('captures redactedContent from modelOutput in sync mode', async () => { + const REDACT_MESSAGE = 'Content blocked.' + const model = bedrock.createModel({ + region: 'us-east-1', + guardrailConfig: { + guardrailIdentifier: GUARDRAIL_ID!, + guardrailVersion: 'DRAFT', + streamProcessingMode: 'sync', + trace: 'enabled_full', // Need full trace to get modelOutput + redaction: { + output: true, + outputMessage: REDACT_MESSAGE, + }, + }, + temperature: 0, + }) + + const messages = [new Message({ role: 'user', content: [new TextBlock('Say CACTUS.')] })] + + // Collect streaming events to check for redactedContent + const events: any[] = [] + for await (const event of model.stream(messages)) { + events.push(event) + } + + // Find the ModelRedactionEvent with outputRedaction + const redactEvent = events.find((e) => e.type === 'modelRedactionEvent' && e.outputRedaction) as + | ModelRedactionEvent + | undefined + + expect(redactEvent).toBeDefined() + expect(redactEvent?.outputRedaction?.replaceContent).toBe(REDACT_MESSAGE) + + // In sync mode with full trace, we should get the original content + // The exact content may vary, but if blocked, redactedContent should be present + if (redactEvent?.outputRedaction?.redactedContent) { + expect(redactEvent.outputRedaction.redactedContent).toContain('CACTUS') + } + }, 30000) + }) + + describe('Tool Result Redaction', () => { + it.each(['sync', 'async'] as const)( + 'properly redacts tool result in %s mode', + async (processingMode) => { + const INPUT_REDACT_MESSAGE = 'Input redacted.' + const OUTPUT_REDACT_MESSAGE = 'Output redacted.' + + const model = bedrock.createModel({ + region: 'us-east-1', + guardrailConfig: { + guardrailIdentifier: GUARDRAIL_ID!, + guardrailVersion: 'DRAFT', + streamProcessingMode: processingMode, + redaction: { + input: true, + inputMessage: INPUT_REDACT_MESSAGE, + output: true, + outputMessage: OUTPUT_REDACT_MESSAGE, + }, + }, + }) + + const listUsers = new FunctionTool({ + name: 'list_users', + description: 'List my users', + inputSchema: { type: 'object', properties: {} }, + callback: async () => { + return '[{"name": "Jerry Merry"}, {"name": "Mr. CACTUS"}]' + }, + }) + + const agent = new Agent({ + model, + systemPrompt: 'You are a helpful assistant.', + tools: [listUsers], + printer: false, + }) + + const response1 = await agent.invoke('List my users.') + const response2 = await agent.invoke('Thank you!') + + /* + * Message sequence: + * 0 (user): request1 + * 1 (assistant): reasoning + tool call + * 2 (user): tool result + * 3 (assistant): response1 -> output guardrail intervenes + * 4 (user): request2 + * 5 (assistant): response2 + * + * Guardrail intervened on output in message 3 will cause + * the redaction of the preceding input (message 2) and message 3. + */ + + expect(response1.stopReason).toBe('guardrailIntervened') + + if (processingMode === 'sync') { + // In sync mode the guardrail processing is blocking + expect(response1.toString()).toContain(OUTPUT_REDACT_MESSAGE) + expect(response2.toString()).not.toContain(OUTPUT_REDACT_MESSAGE) + } + + // In both sync and async with output redaction: + // 1. Content should be properly redacted so response2 is not blocked + expect(response2.stopReason).not.toBe('guardrailIntervened') + + // 2. Tool result block should be redacted properly + const toolUseMessage = agent.messages[1] + const toolResultMessage = agent.messages[2] + + expect(toolUseMessage).toBeDefined() + expect(toolResultMessage).toBeDefined() + + const toolUseBlock = toolUseMessage?.content.find((b) => b.type === 'toolUseBlock') + const toolResultBlock = toolResultMessage?.content.find((b) => b.type === 'toolResultBlock') + + expect(toolUseBlock).toBeDefined() + expect(toolResultBlock).toBeDefined() + + if (toolUseBlock?.type === 'toolUseBlock' && toolResultBlock?.type === 'toolResultBlock') { + expect(toolResultBlock.toolUseId).toBe(toolUseBlock.toolUseId) + const firstContent = toolResultBlock.content[0] + expect(firstContent).toBeDefined() + if (firstContent?.type === 'textBlock') { + expect((firstContent as TextBlock).text).toBe(INPUT_REDACT_MESSAGE) + } + } + }, + 30000 + ) + }) + + describe('guardLatestUserMessage', () => { + it('allows conversation when latest user message is clean even if earlier messages would trigger guardrails', async () => { + // Load test image + const imageBytes = await loadFixture(yellowPngUrl) + + // Create model with guardLatestUserMessage enabled + const model = bedrock.createModel({ + region: 'us-east-1', + guardrailConfig: { + guardrailIdentifier: GUARDRAIL_ID!, + guardrailVersion: 'DRAFT', + guardLatestUserMessage: true, + }, + }) + + // Create agent with previous messages that CONTAIN blocked content (CACTUS) + // When guardLatestUserMessage is enabled, these earlier messages should NOT trigger the guardrail + const agent = new Agent({ + model, + printer: false, + systemPrompt: 'You are a helpful assistant.', + messages: [ + new Message({ + role: 'user', + content: [ + new TextBlock('Dont Say CACTUS'), + new ImageBlock({ format: 'png', source: { bytes: imageBytes } }), + ], + }), + new Message({ role: 'assistant', content: [new TextBlock('Hello!')] }), + ], + }) + + // Send a clean message - should NOT trigger guardrail because only the latest message is evaluated + const response = await agent.invoke('Hello!') + + expect(response.stopReason).not.toBe('guardrailIntervened') + }, 30000) + + it('blocks conversation when latest user message contains blocked content', async () => { + // Create model with guardLatestUserMessage enabled + const model = bedrock.createModel({ + region: 'us-east-1', + guardrailConfig: { + guardrailIdentifier: GUARDRAIL_ID!, + guardrailVersion: 'DRAFT', + guardLatestUserMessage: true, + }, + }) + + // Send message with blocked content + const agent = new Agent({ + model, + printer: false, + systemPrompt: 'You are a helpful assistant.', + }) + + const response = await agent.invoke('Tell me about CACTUS plants') + + // The guardrail should have intervened + expect(response.stopReason).toBe('guardrailIntervened') + expect(response.toString()).toContain(BLOCKED_INPUT) + }, 30000) + }) + }) + + describe('countTokens', () => { + const messages = [ + new Message({ role: 'user', content: [new TextBlock('What is the capital of France? Explain in detail.')] }), + ] + const toolSpecs = [ + { + name: 'get_weather', + description: 'Get the current weather for a location', + inputSchema: { type: 'object' as const, properties: { location: { type: 'string' as const } } }, + }, + ] + + it.concurrent('should count tokens for messages only', async () => { + const model = bedrock.createModel() + const result = await model.countTokens(messages) + expect(typeof result).toBe('number') + expect(result).toBeGreaterThan(0) + }) + + it.concurrent('should return more tokens with tools and system prompt', async () => { + const model = bedrock.createModel() + const without = await model.countTokens(messages) + const withTools = await model.countTokens(messages, { toolSpecs, systemPrompt: 'Be helpful.' }) + expect(withTools).toBeGreaterThan(without) + }) + }) +}) diff --git a/strands-ts/test/integ/models/google.test.ts b/strands-ts/test/integ/models/google.test.ts new file mode 100644 index 0000000000..fd901a5276 --- /dev/null +++ b/strands-ts/test/integ/models/google.test.ts @@ -0,0 +1,160 @@ +import { describe, expect, it } from 'vitest' +import { Message, TextBlock } from '@strands-agents/sdk' +import type { ModelStreamEvent } from '$/sdk/models/streaming.js' + +import { collectIterator } from '$/sdk/__fixtures__/model-test-helpers.js' + +import { gemini } from '../__fixtures__/model-providers.js' + +/** + * Gemini-specific integration tests. + * + * Tests for functionality covered by agent.test.ts (system prompts, conversation context, + * media content, reasoning, basic agent usage) are intentionally omitted here to avoid duplication. + * This file focuses on low-level model provider behavior specific to Gemini. + */ +describe.skipIf(gemini.skip)('GoogleModel Integration Tests', () => { + describe('Streaming', () => { + describe('Configuration', () => { + it.concurrent('respects temperature configuration', async () => { + const provider = gemini.createModel({ + modelId: 'gemini-2.0-flash', + params: { temperature: 0, maxOutputTokens: 50 }, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Say "hello world" exactly.')], + }), + ] + + const events1 = await collectIterator(provider.stream(messages)) + const events2 = await collectIterator(provider.stream(messages)) + + let text1 = '' + let text2 = '' + + for (const event of events1) { + if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + text1 += event.delta.text + } + } + + for (const event of events2) { + if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + text2 += event.delta.text + } + } + + expect(text1.length).toBeGreaterThan(0) + expect(text2.length).toBeGreaterThan(0) + expect(text1.toLowerCase()).toContain('hello') + expect(text2.toLowerCase()).toContain('hello') + }) + }) + + describe('Error Handling', () => { + it.concurrent('handles invalid model ID gracefully', async () => { + const provider = gemini.createModel({ + modelId: 'invalid-model-id-that-does-not-exist-xyz', + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Hello')], + }), + ] + + await expect(collectIterator(provider.stream(messages))).rejects.toThrow(/not found/i) + }) + }) + + describe('Content Block Lifecycle', () => { + it.concurrent('emits complete content block lifecycle events', async () => { + const provider = gemini.createModel({ + modelId: 'gemini-2.0-flash', + params: { maxOutputTokens: 50 }, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Say hello.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const startEvents = events.filter((e) => e.type === 'modelContentBlockStartEvent') + const deltaEvents = events.filter((e) => e.type === 'modelContentBlockDeltaEvent') + const stopEvents = events.filter((e) => e.type === 'modelContentBlockStopEvent') + + expect(startEvents.length).toBeGreaterThan(0) + expect(deltaEvents.length).toBeGreaterThan(0) + expect(stopEvents.length).toBeGreaterThan(0) + + const startIndex = events.findIndex((e) => e.type === 'modelContentBlockStartEvent') + const firstDeltaIndex = events.findIndex((e) => e.type === 'modelContentBlockDeltaEvent') + expect(startIndex).toBeLessThan(firstDeltaIndex) + + const stopIndex = events.findIndex((e) => e.type === 'modelContentBlockStopEvent') + const lastDeltaIndex = events + .map((e, i) => (e.type === 'modelContentBlockDeltaEvent' ? i : -1)) + .filter((i) => i !== -1) + .pop()! + expect(stopIndex).toBeGreaterThan(lastDeltaIndex) + }) + }) + + describe('Stop Reasons', () => { + it.concurrent('returns endTurn stop reason for natural completion', async () => { + const provider = gemini.createModel({ + modelId: 'gemini-2.0-flash', + params: { maxOutputTokens: 100 }, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Say hi.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent).toBeDefined() + expect(messageStopEvent?.stopReason).toBe('endTurn') + }) + }) + }) + + describe('countTokens', () => { + const messages = [ + new Message({ role: 'user', content: [new TextBlock('What is the capital of France? Explain in detail.')] }), + ] + const toolSpecs = [ + { + name: 'get_weather', + description: 'Get the current weather for a location', + inputSchema: { type: 'object' as const, properties: { location: { type: 'string' as const } } }, + }, + ] + + it.concurrent('should count tokens for messages only', async () => { + const model = gemini.createModel() + const result = await model.countTokens(messages) + expect(typeof result).toBe('number') + expect(result).toBeGreaterThan(0) + }) + + it.concurrent('should return more tokens with tools and system prompt', async () => { + const model = gemini.createModel() + const without = await model.countTokens(messages) + const withTools = await model.countTokens(messages, { toolSpecs, systemPrompt: 'Be helpful.' }) + expect(withTools).toBeGreaterThan(without) + }) + }) +}) diff --git a/strands-ts/test/integ/models/openai/chat.test.ts b/strands-ts/test/integ/models/openai/chat.test.ts new file mode 100644 index 0000000000..f5f272c276 --- /dev/null +++ b/strands-ts/test/integ/models/openai/chat.test.ts @@ -0,0 +1,212 @@ +import { describe, expect, it } from 'vitest' +import type { ToolSpec } from '@strands-agents/sdk' +import { Message, TextBlock } from '@strands-agents/sdk' + +import { collectIterator } from '$/sdk/__fixtures__/model-test-helpers.js' + +import { openai } from '../../__fixtures__/model-providers.js' + +describe.skipIf(openai.skip)('OpenAIModel Integration Tests', () => { + describe('Configuration', () => { + it.concurrent('respects maxTokens configuration', async () => { + const provider = openai.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 20, // Very small limit + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Write a long story about dragons.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + // Check metadata for token usage + const metadataEvent = events.find((e) => e.type === 'modelMetadataEvent') + expect(metadataEvent?.usage?.outputTokens).toBeLessThanOrEqual(25) // Allow small buffer + + // Check that stop reason is maxTokens + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent?.stopReason).toBe('maxTokens') + }) + + it.concurrent('respects temperature configuration', async () => { + const provider = openai.createModel({ + modelId: 'gpt-5.4-mini', + temperature: 0, // Deterministic + maxTokens: 50, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Say "hello world" exactly.')], + }), + ] + + const events1 = await collectIterator(provider.stream(messages)) + const events2 = await collectIterator(provider.stream(messages)) + + // Collect text from both runs + let text1 = '' + let text2 = '' + + for (const event of events1) { + if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + text1 += event.delta.text + } + } + + for (const event of events2) { + if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + text2 += event.delta.text + } + } + + // With temperature=0, responses should be very similar or identical + expect(text1.length).toBeGreaterThan(0) + expect(text2.length).toBeGreaterThan(0) + // Both should contain "hello" in some form + expect(text1.toLowerCase()).toContain('hello') + expect(text2.toLowerCase()).toContain('hello') + }) + }) + + describe('Error Handling', () => { + it.concurrent('handles invalid model ID gracefully', async () => { + const provider = openai.createModel({ + modelId: 'invalid-model-id-that-does-not-exist-xyz', + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Hello')], + }), + ] + + // Should throw an error (OpenAI will reject the invalid model) + await expect(async () => { + for await (const _event of provider.stream(messages)) { + throw Error('Should not get here') + } + }).rejects.toThrow() + }) + }) + + describe('Content Block Lifecycle', () => { + it.concurrent('emits complete content block lifecycle events', async () => { + const provider = openai.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 50, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Say hello.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + // Verify complete lifecycle: start -> delta(s) -> stop + const startEvents = events.filter((e) => e.type === 'modelContentBlockStartEvent') + const deltaEvents = events.filter((e) => e.type === 'modelContentBlockDeltaEvent') + const stopEvents = events.filter((e) => e.type === 'modelContentBlockStopEvent') + + expect(startEvents.length).toBeGreaterThan(0) + expect(deltaEvents.length).toBeGreaterThan(0) + expect(stopEvents.length).toBeGreaterThan(0) + + // Start should come before delta + const startIndex = events.findIndex((e) => e.type === 'modelContentBlockStartEvent') + const firstDeltaIndex = events.findIndex((e) => e.type === 'modelContentBlockDeltaEvent') + expect(startIndex).toBeLessThan(firstDeltaIndex) + + // Stop should come after all deltas + const stopIndex = events.findIndex((e) => e.type === 'modelContentBlockStopEvent') + const lastDeltaIndex = events + .map((e, i) => (e.type === 'modelContentBlockDeltaEvent' ? i : -1)) + .filter((i) => i !== -1) + .pop()! + expect(stopIndex).toBeGreaterThan(lastDeltaIndex) + }) + }) + + describe('Stop Reasons', () => { + it.concurrent('returns endTurn stop reason for natural completion', async () => { + const provider = openai.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 100, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Say hi.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent).toBeDefined() + expect(messageStopEvent?.stopReason).toBe('endTurn') + }) + + it.concurrent('returns maxTokens stop reason when token limit reached', async () => { + const provider = openai.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 10, // Very small limit to force cutoff + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Write a very long story about dragons.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent).toBeDefined() + expect(messageStopEvent?.stopReason).toBe('maxTokens') + }) + + it.concurrent('returns toolUse stop reason when requesting tool use', async () => { + const provider = openai.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 200, + }) + + const calculatorTool: ToolSpec = { + name: 'calculator', + description: 'Performs basic arithmetic operations. Use this to calculate math expressions.', + inputSchema: { + type: 'object', + properties: { + expression: { type: 'string', description: 'The math expression to calculate' }, + }, + required: ['expression'], + }, + } + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Calculate 42 times 7 please.')], + }), + ] + + const events = await collectIterator(provider.stream(messages, { toolSpecs: [calculatorTool] })) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent).toBeDefined() + expect(messageStopEvent?.stopReason).toBe('toolUse') + }) + }) +}) diff --git a/strands-ts/test/integ/models/openai/mantle.test.node.ts b/strands-ts/test/integ/models/openai/mantle.test.node.ts new file mode 100644 index 0000000000..e2dbd51fcd --- /dev/null +++ b/strands-ts/test/integ/models/openai/mantle.test.node.ts @@ -0,0 +1,93 @@ +/** + * Integration tests for the OpenAI-compatible Bedrock Mantle pathway. + * + * Exercises `OpenAIModel` with `bedrockMantleConfig` against the live + * `bedrock-mantle..api.aws/v1` endpoint. Credentials come from the + * ambient AWS credential chain (same gate as the other Bedrock integ tests). + */ + +import { describe, expect, it } from 'vitest' +import { Agent } from '@strands-agents/sdk' +import { OpenAIModel } from '$/sdk/models/openai/index.js' + +import { bedrock } from '../../__fixtures__/model-providers.js' + +const REGION = 'us-east-1' +const MODEL_ID = 'openai.gpt-oss-120b' + +describe.skipIf(bedrock.skip)('OpenAIModel (Bedrock Mantle) Integration Tests', () => { + it('reaches Mantle via bedrockMantleConfig on the Chat Completions API', async () => { + const model = new OpenAIModel({ + api: 'chat', + modelId: MODEL_ID, + bedrockMantleConfig: { region: REGION }, + }) + const agent = new Agent({ + model, + systemPrompt: 'Reply in one short sentence.', + printer: false, + }) + + const result = await agent.invoke('What is 2+2?') + + expect(result.stopReason).toBe('endTurn') + expect(String(result)).toContain('4') + }) + + it('reaches Mantle via bedrockMantleConfig on the Responses API', async () => { + const model = new OpenAIModel({ + modelId: MODEL_ID, + bedrockMantleConfig: { region: REGION }, + }) + const agent = new Agent({ + model, + systemPrompt: 'Reply in one short sentence.', + printer: false, + }) + + const result = await agent.invoke('What is 2+2?') + + expect(result.stopReason).toBe('endTurn') + expect(String(result)).toContain('4') + }) + + it('supports server-side stateful conversations', async () => { + const model = new OpenAIModel({ + modelId: MODEL_ID, + stateful: true, + bedrockMantleConfig: { region: REGION }, + }) + const agent = new Agent({ + model, + systemPrompt: 'Reply in one short sentence.', + printer: false, + }) + + await agent.invoke('My name is Alice.') + const result = await agent.invoke('What is my name?') + + expect(String(result).toLowerCase()).toContain('alice') + }) + + it('handles reasoning content across multi-turn conversations', async () => { + const model = new OpenAIModel({ + modelId: MODEL_ID, + bedrockMantleConfig: { region: REGION }, + params: { reasoning: { effort: 'low' } }, + }) + const agent = new Agent({ + model, + systemPrompt: 'Reply in one short sentence.', + printer: false, + }) + + const first = await agent.invoke('What is 2+2?') + expect(String(first)).toContain('4') + + // Second turn must not throw despite reasoningContent blocks present in + // the message history. The local response shape varies by effort level, + // so we only assert the round-trip completes cleanly. + const second = await agent.invoke('What about 3+3?') + expect(second.stopReason).toBe('endTurn') + }) +}) diff --git a/strands-ts/test/integ/models/openai/responses.test.ts b/strands-ts/test/integ/models/openai/responses.test.ts new file mode 100644 index 0000000000..3631bfcacd --- /dev/null +++ b/strands-ts/test/integ/models/openai/responses.test.ts @@ -0,0 +1,335 @@ +import { describe, expect, it } from 'vitest' +import { z } from 'zod' +import type { ToolSpec } from '@strands-agents/sdk' +import { Agent, Message, TextBlock, tool } from '@strands-agents/sdk' + +import { collectIterator } from '$/sdk/__fixtures__/model-test-helpers.js' + +import { openaiResponses } from '../../__fixtures__/model-providers.js' + +describe.skipIf(openaiResponses.skip)("OpenAIModel (api: 'responses') Integration Tests", () => { + describe('Configuration', () => { + it.concurrent('respects maxTokens configuration', async () => { + const provider = openaiResponses.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 20, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Write a long story about dragons.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const metadataEvent = events.find((e) => e.type === 'modelMetadataEvent') + expect(metadataEvent?.usage?.outputTokens).toBeLessThanOrEqual(25) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent?.stopReason).toBe('maxTokens') + }) + + it.concurrent('respects temperature configuration', async () => { + const provider = openaiResponses.createModel({ + modelId: 'gpt-5.4-mini', + temperature: 0, + maxTokens: 50, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Say "hello world" exactly.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + let text = '' + for (const event of events) { + if (event.type === 'modelContentBlockDeltaEvent' && event.delta.type === 'textDelta') { + text += event.delta.text + } + } + + expect(text.toLowerCase()).toContain('hello') + }) + }) + + describe('Content Block Lifecycle', () => { + it.concurrent('emits complete content block lifecycle events', async () => { + const provider = openaiResponses.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 50, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Say hello.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const startEvents = events.filter((e) => e.type === 'modelContentBlockStartEvent') + const deltaEvents = events.filter((e) => e.type === 'modelContentBlockDeltaEvent') + const stopEvents = events.filter((e) => e.type === 'modelContentBlockStopEvent') + + expect(startEvents.length).toBeGreaterThan(0) + expect(deltaEvents.length).toBeGreaterThan(0) + expect(stopEvents.length).toBeGreaterThan(0) + + const startIndex = events.findIndex((e) => e.type === 'modelContentBlockStartEvent') + const firstDeltaIndex = events.findIndex((e) => e.type === 'modelContentBlockDeltaEvent') + expect(startIndex).toBeLessThan(firstDeltaIndex) + + const stopIndex = events.findIndex((e) => e.type === 'modelContentBlockStopEvent') + const lastDeltaIndex = events + .map((e, i) => (e.type === 'modelContentBlockDeltaEvent' ? i : -1)) + .filter((i) => i !== -1) + .pop()! + expect(stopIndex).toBeGreaterThan(lastDeltaIndex) + }) + }) + + describe('Stop Reasons', () => { + it.concurrent('returns endTurn stop reason for natural completion', async () => { + const provider = openaiResponses.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 100, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Say hi.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent?.stopReason).toBe('endTurn') + }) + + it.concurrent('returns maxTokens stop reason when token limit reached', async () => { + const provider = openaiResponses.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 16, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Write a very long story about dragons.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent?.stopReason).toBe('maxTokens') + }) + + it.concurrent('returns toolUse stop reason when requesting tool use', async () => { + const provider = openaiResponses.createModel({ + modelId: 'gpt-5.4-mini', + maxTokens: 200, + }) + + const calculatorTool: ToolSpec = { + name: 'calculator', + description: 'Performs basic arithmetic operations. Use this to calculate math expressions.', + inputSchema: { + type: 'object', + properties: { + expression: { type: 'string', description: 'The math expression to calculate' }, + }, + required: ['expression'], + }, + } + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Calculate 42 times 7 please.')], + }), + ] + + const events = await collectIterator(provider.stream(messages, { toolSpecs: [calculatorTool] })) + + const messageStopEvent = events.find((e) => e.type === 'modelMessageStopEvent') + expect(messageStopEvent?.stopReason).toBe('toolUse') + }) + }) + + describe('Stateful Conversation', () => { + it('tracks conversation across turns via server-side state', async () => { + const model = openaiResponses.createModel({ + modelId: 'gpt-5.4-mini', + stateful: true, + }) + const agent = new Agent({ + model, + printer: false, + systemPrompt: 'Reply in one short sentence.', + }) + + await agent.invoke('My name is Alice.') + expect(agent.messages).toHaveLength(0) + + const result = await agent.invoke('What is my name?') + const text = result.lastMessage.content + .filter((block) => block.type === 'textBlock') + .map((block) => block.text) + .join('') + .toLowerCase() + expect(text).toContain('alice') + }) + + it('completes an agent-loop round-trip with a user-defined function tool', async () => { + // Exercises the stateful + function-tool wire path end-to-end: the agent + // executes the callback, then sends a second Responses request carrying + // previous_response_id plus a function_call_output item. Nothing else in + // this suite covers that follow-up request — the existing toolUse test + // stops at the first chunk, and the built-in tool tests (web_search / + // code_interpreter) use a different serialization path. Assertions are + // purely mechanical to stay deterministic. + let callCount = 0 + const pingTool = tool({ + name: 'ping', + description: 'Returns a fixed acknowledgement. Use this when the user asks you to ping.', + inputSchema: z.object({}), + callback: async () => { + callCount++ + return 'pong' + }, + }) + + const model = openaiResponses.createModel({ + modelId: 'gpt-5.4-mini', + stateful: true, + }) + const agent = new Agent({ + model, + printer: false, + systemPrompt: 'Use the ping tool when asked to ping.', + tools: [pingTool], + }) + + const result = await agent.invoke('Please ping.') + + expect(result.stopReason).toBe('endTurn') + expect(callCount).toBeGreaterThanOrEqual(1) + expect(result.metrics?.toolMetrics['ping']?.successCount).toBeGreaterThanOrEqual(1) + expect(agent.messages).toEqual([]) + expect(agent.modelState.get('responseId')).toEqual(expect.any(String)) + }) + }) + + describe('Built-in Tools', () => { + it.concurrent('web_search produces text with citations', async () => { + const model = openaiResponses.createModel({ + modelId: 'gpt-4o', + params: { tools: [{ type: 'web_search' }] }, + }) + const agent = new Agent({ + model, + printer: false, + systemPrompt: 'Answer concisely.', + }) + + const result = await agent.invoke('Search https://strandsagents.com/ and tell me what Strands Agents is.') + const citationsBlock = result.lastMessage.content.find((block) => block.type === 'citationsBlock') + expect(citationsBlock).toBeDefined() + }) + + it.concurrent('code_interpreter produces correct results', async () => { + const model = openaiResponses.createModel({ + modelId: 'gpt-4o', + params: { tools: [{ type: 'code_interpreter', container: { type: 'auto' } }] }, + }) + const agent = new Agent({ + model, + printer: false, + systemPrompt: 'Answer concisely.', + }) + + const result = await agent.invoke("Compute the SHA-256 hash of the string 'strands'. Return only the hex digest.") + const text = result.lastMessage.content + .filter((block) => block.type === 'textBlock') + .map((block) => block.text) + .join('') + expect(text).toContain('11e0e34bd35e12185cfacd5e5a256ab4292bfa3616d8d5b74e20eca36feed228') + }) + }) + + describe('Citation Block Switching', () => { + it.concurrent('text and citations land in separate content blocks', async () => { + const provider = openaiResponses.createModel({ + modelId: 'gpt-4o', + params: { tools: [{ type: 'web_search' }] }, + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Search the web and tell me what Strands Agents is. Cite your sources.')], + }), + ] + + const events = await collectIterator(provider.stream(messages)) + + const textDeltas = events.filter((e) => e.type === 'modelContentBlockDeltaEvent' && e.delta.type === 'textDelta') + const citationDeltas = events.filter( + (e) => e.type === 'modelContentBlockDeltaEvent' && e.delta.type === 'citationsDelta' + ) + + expect(textDeltas.length).toBeGreaterThan(0) + expect(citationDeltas.length).toBeGreaterThan(0) + + // Every citation delta must be preceded by a block start (not a text delta in the same block). + // This verifies the _switchContent('citations', ...) logic closes the text block first. + for (const citationDelta of citationDeltas) { + const citationIndex = events.indexOf(citationDelta) + const precedingEvents = events.slice(0, citationIndex) + + let lastStart = -1 + let lastTextDelta = -1 + for (let i = 0; i < precedingEvents.length; i++) { + const ev = precedingEvents[i]! + if (ev.type === 'modelContentBlockStartEvent') lastStart = i + if (ev.type === 'modelContentBlockDeltaEvent' && ev.delta.type === 'textDelta') lastTextDelta = i + } + + if (lastTextDelta !== -1) { + expect(lastStart).toBeGreaterThan(lastTextDelta) + } + } + }) + }) + + describe('Error Handling', () => { + it.concurrent('handles invalid model ID gracefully', async () => { + const provider = openaiResponses.createModel({ + modelId: 'invalid-model-id-that-does-not-exist-xyz', + }) + + const messages: Message[] = [ + new Message({ + role: 'user', + content: [new TextBlock('Hello')], + }), + ] + + await expect(async () => { + for await (const _event of provider.stream(messages)) { + throw Error('Should not get here') + } + }).rejects.toThrow() + }) + }) +}) diff --git a/strands-ts/test/integ/multiagent/_interrupt-helpers.ts b/strands-ts/test/integ/multiagent/_interrupt-helpers.ts new file mode 100644 index 0000000000..5ce32d0dd0 --- /dev/null +++ b/strands-ts/test/integ/multiagent/_interrupt-helpers.ts @@ -0,0 +1,35 @@ +/** + * Shared helpers for multi-agent interrupt integration tests. + * + * Leading-underscore filename keeps it out of vitest's auto-discovery. + */ +import type { InterruptResponseContentData, JSONValue } from '@strands-agents/sdk' +import { Status } from '$/sdk/multiagent/index.js' +import type { MultiAgentResult } from '$/sdk/multiagent/index.js' +import { SessionManager } from '$/sdk/session/session-manager.js' +import { FileStorage } from '$/sdk/session/file-storage.js' + +export function makeSessionManager(sessionId: string, storageDir: string): SessionManager { + return new SessionManager({ sessionId, storage: { snapshot: new FileStorage(storageDir) } }) +} + +/** + * Resumes an interrupted orchestrator by answering all pending interrupts, looping + * until the run terminates or we hit the max iteration limit. Used for both Graph + * and Swarm. + */ +export async function resumeUntilDone( + invoke: (responses: InterruptResponseContentData[]) => Promise, + initial: MultiAgentResult, + respond: (interrupt: { id: string; name: string; reason?: unknown }) => JSONValue, + maxRounds = 5 +): Promise { + let current = initial + for (let i = 0; i < maxRounds && current.status === Status.INTERRUPTED; i++) { + const responses: InterruptResponseContentData[] = current.interrupts!.map((interrupt) => ({ + interruptResponse: { interruptId: interrupt.id, response: respond(interrupt) }, + })) + current = await invoke(responses) + } + return current +} diff --git a/strands-ts/test/integ/multiagent/graph.test.ts b/strands-ts/test/integ/multiagent/graph.test.ts new file mode 100644 index 0000000000..3a78111b80 --- /dev/null +++ b/strands-ts/test/integ/multiagent/graph.test.ts @@ -0,0 +1,213 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '$/sdk/agent/agent.js' +import { Graph, Swarm, Status } from '$/sdk/multiagent/index.js' +import { collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { bedrock } from '../__fixtures__/model-providers.js' + +describe.skipIf(bedrock.skip)('Graph', () => { + const createModel = (maxTokens = 1024) => bedrock.createModel({ maxTokens }) + + it('completes single-node execution with lifecycle events', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + id: 'assistant', + systemPrompt: 'Answer in one word only.', + }) + + const graph = new Graph({ + nodes: [agent], + edges: [], + }) + + const { items, result } = await collectGenerator(graph.stream('What is the capital of France?')) + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + duration: expect.any(Number), + }) + ) + expect(result.results).toHaveLength(1) + expect(result.results[0]!.nodeId).toBe('assistant') + + const text = result.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/Paris/i) + + const eventTypes = items.map((e) => e.type) + expect(eventTypes[0]).toBe('beforeMultiAgentInvocationEvent') + expect(eventTypes).toContain('beforeNodeCallEvent') + expect(eventTypes).toContain('nodeStreamUpdateEvent') + expect(eventTypes).toContain('nodeResultEvent') + expect(eventTypes).toContain('afterNodeCallEvent') + expect(eventTypes).toContain('afterMultiAgentInvocationEvent') + expect(eventTypes).toContain('multiAgentResultEvent') + }) + + it('executes linear graph with handoff events', async () => { + const researcher = new Agent({ + model: createModel(), + printer: false, + id: 'researcher', + systemPrompt: 'Research the topic and provide key facts in 1-2 sentences.', + }) + + const writer = new Agent({ + model: createModel(), + printer: false, + id: 'writer', + systemPrompt: 'Rewrite the input as a single polished sentence.', + }) + + const graph = new Graph({ + nodes: [researcher, writer], + edges: [['researcher', 'writer']], + }) + + const { items, result } = await collectGenerator(graph.stream('What is the largest ocean?')) + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + duration: expect.any(Number), + }) + ) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['researcher', 'writer']) + + const text = result.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/Pacific/i) + + const handoff = items.find((e) => e.type === 'multiAgentHandoffEvent') + expect(handoff).toEqual( + expect.objectContaining({ + source: 'researcher', + targets: ['writer'], + }) + ) + }) + + it('executes parallel fan-out graph', async () => { + const router = new Agent({ + model: createModel(), + printer: false, + id: 'router', + systemPrompt: 'Repeat the user input exactly.', + }) + + const capitals = new Agent({ + model: createModel(), + printer: false, + id: 'capitals', + systemPrompt: 'Answer with only the capital of France in one word.', + }) + + const oceans = new Agent({ + model: createModel(), + printer: false, + id: 'oceans', + systemPrompt: 'Answer with only the largest ocean in one word.', + }) + + const graph = new Graph({ + nodes: [router, capitals, oceans], + edges: [ + ['router', 'capitals'], + ['router', 'oceans'], + ], + }) + + const result = await graph.invoke('Go') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + duration: expect.any(Number), + }) + ) + expect(result.results).toHaveLength(3) + expect(result.results.map((r) => r.nodeId).sort()).toStrictEqual(['capitals', 'oceans', 'router']) + + const text = result.content.map((b) => (b.type === 'textBlock' ? b.text : '')).join(' ') + expect(text).toMatch(/Paris/i) + expect(text).toMatch(/Pacific/i) + }) + + it('executes nested graph through MultiAgentNode', async () => { + const inner = new Swarm({ + id: 'inner-swarm', + nodes: [ + new Agent({ + model: createModel(), + printer: false, + id: 'answerer', + description: 'Answers questions in one word.', + systemPrompt: 'Answer in one word only.', + }), + ], + start: 'answerer', + }) + + const summarizer = new Agent({ + model: createModel(), + printer: false, + id: 'summarizer', + systemPrompt: 'Repeat the input exactly as given.', + }) + + const graph = new Graph({ + nodes: [inner, summarizer], + edges: [['inner-swarm', 'summarizer']], + }) + + const result = await graph.invoke('What is the capital of Japan?') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + duration: expect.any(Number), + }) + ) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['inner-swarm', 'summarizer']) + + const text = result.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/Tokyo/i) + }) + + it('executes cycle with conditional edge that breaks after one iteration', async () => { + let visits = 0 + + const agent = new Agent({ + model: createModel(), + printer: false, + id: 'counter', + systemPrompt: 'Reply with the single word "counted".', + }) + + const graph = new Graph({ + nodes: [agent], + edges: [ + { + source: 'counter', + target: 'counter', + handler: () => { + visits++ + return visits < 2 + }, + }, + ], + sources: ['counter'], + }) + + const result = await graph.invoke('Go') + + expect(result).toEqual( + expect.objectContaining({ + status: Status.COMPLETED, + duration: expect.any(Number), + }) + ) + expect(result.results).toHaveLength(2) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['counter', 'counter']) + expect(visits).toBe(2) + }) +}) diff --git a/strands-ts/test/integ/multiagent/interrupt-hook.test.node.ts b/strands-ts/test/integ/multiagent/interrupt-hook.test.node.ts new file mode 100644 index 0000000000..79229f2515 --- /dev/null +++ b/strands-ts/test/integ/multiagent/interrupt-hook.test.node.ts @@ -0,0 +1,106 @@ +/** + * Integration tests for orchestrator-hook interrupts — interrupts raised from + * `BeforeNodeCallEvent.interrupt()` to gate a node before it runs. + */ +import { describe, expect, it } from 'vitest' +import { z } from 'zod' +import { Agent, tool } from '@strands-agents/sdk' +import { Graph, Swarm, Status, BeforeNodeCallEvent } from '$/sdk/multiagent/index.js' +import { bedrock } from '../__fixtures__/model-providers.js' +import { resumeUntilDone } from './_interrupt-helpers.js' + +const weatherTool = tool({ + name: 'weather_tool', + description: 'Returns the current weather.', + inputSchema: z.object({}), + callback: async () => 'sunny', +}) + +describe.skipIf(bedrock.skip)('Multi-agent orchestrator-hook interrupts', () => { + const createModel = (maxTokens = 1024) => bedrock.createModel({ maxTokens }) + + it('Graph: hook gates a node before it runs, resume approves', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + id: 'execute', + tools: [weatherTool], + systemPrompt: 'Use the tool and briefly answer.', + }) + + const graph = new Graph({ nodes: [agent], edges: [] }) + graph.addHook(BeforeNodeCallEvent, (event) => { + if (event.nodeId !== 'execute') return + const response = event.interrupt({ name: 'execute_approval', reason: 'approve?' }) + if (response !== 'APPROVE') event.cancel = 'rejected' + }) + + const result = await graph.invoke('What is the weather?') + expect(result.status).toBe(Status.INTERRUPTED) + expect(result.interrupts![0]!.source).toBe('multiagent-hook') + expect(result.interrupts![0]!.name).toBe('execute_approval') + + const finalResult = await resumeUntilDone( + (responses) => graph.invoke(responses), + result, + () => 'APPROVE' + ) + expect(finalResult.status).toBe(Status.COMPLETED) + }) + + it('Graph: hook rejection cancels the node', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + id: 'execute', + tools: [weatherTool], + }) + + const graph = new Graph({ nodes: [agent], edges: [] }) + graph.addHook(BeforeNodeCallEvent, (event) => { + if (event.nodeId !== 'execute') return + const response = event.interrupt({ name: 'execute_approval', reason: 'approve?' }) + if (response !== 'APPROVE') event.cancel = 'rejected' + }) + + const result = await graph.invoke('anything') + expect(result.status).toBe(Status.INTERRUPTED) + + const finalResult = await resumeUntilDone( + (responses) => graph.invoke(responses), + result, + () => 'REJECT' + ) + const executeResult = finalResult.results.find((r) => r.nodeId === 'execute') + expect(executeResult?.status).toBe(Status.CANCELLED) + }) + + it('Swarm: hook gates the start node, resume approves', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + id: 'assistant', + description: 'Answers questions briefly.', + systemPrompt: 'Answer in one word only.', + }) + + const swarm = new Swarm({ nodes: [agent], start: 'assistant' }) + swarm.addHook(BeforeNodeCallEvent, (event) => { + event.interrupt({ name: 'approval', reason: 'approve?' }) + }) + + const result = await swarm.invoke('What is the capital of France?') + expect(result.status).toBe(Status.INTERRUPTED) + expect(result.interrupts![0]!.source).toBe('multiagent-hook') + + const finalResult = await resumeUntilDone( + (responses) => swarm.invoke(responses), + result, + () => 'APPROVE' + ) + expect(finalResult.status).toBe(Status.COMPLETED) + + const text = finalResult.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/Paris/i) + }) +}) diff --git a/strands-ts/test/integ/multiagent/interrupt-node.test.node.ts b/strands-ts/test/integ/multiagent/interrupt-node.test.node.ts new file mode 100644 index 0000000000..a9c1fff90e --- /dev/null +++ b/strands-ts/test/integ/multiagent/interrupt-node.test.node.ts @@ -0,0 +1,79 @@ +/** + * Integration tests for interrupts raised by tool callbacks inside a node's child + * agent. Exercises Graph and Swarm routing the interrupt up through the orchestrator + * result, then resuming with a response. + */ +import { describe, expect, it } from 'vitest' +import { z } from 'zod' +import { Agent, tool } from '@strands-agents/sdk' +import { Graph, Swarm, Status } from '$/sdk/multiagent/index.js' +import { bedrock } from '../__fixtures__/model-providers.js' +import { resumeUntilDone } from './_interrupt-helpers.js' + +const interruptingWeatherTool = tool({ + name: 'weather_tool', + description: 'Returns the current weather.', + inputSchema: z.object({}), + callback: async (_input, context) => + context!.interrupt({ name: 'weather_interrupt', reason: 'need weather' }) as string, +}) + +describe.skipIf(bedrock.skip)('Multi-agent tool-callback interrupts', () => { + const createModel = (maxTokens = 1024) => bedrock.createModel({ maxTokens }) + + it('Graph: tool inside a node interrupts, resumes', async () => { + const weatherAgent = new Agent({ + model: createModel(), + printer: false, + id: 'weather', + tools: [interruptingWeatherTool], + systemPrompt: 'Use the weather tool to answer the user.', + }) + + const graph = new Graph({ nodes: [weatherAgent], edges: [] }) + + const result = await graph.invoke('What is the weather?') + expect(result.status).toBe(Status.INTERRUPTED) + expect(result.interrupts).toBeDefined() + expect(result.interrupts![0]!.name).toBe('weather_interrupt') + expect(result.interrupts![0]!.source).toBe('tool') + + const finalResult = await resumeUntilDone( + (responses) => graph.invoke(responses), + result, + () => 'cloudy' + ) + expect(finalResult.status).toBe(Status.COMPLETED) + + const text = finalResult.content + .filter((b) => b.type === 'textBlock') + .map((b) => b.text) + .join(' ') + .toLowerCase() + expect(text).toMatch(/cloudy/) + }) + + it('Swarm: tool inside the start agent interrupts, resumes', async () => { + const weatherAgent = new Agent({ + model: createModel(), + printer: false, + id: 'weather', + tools: [interruptingWeatherTool], + description: 'Fetches weather data.', + systemPrompt: 'Use the weather tool, then produce a final response with no handoff.', + }) + + const swarm = new Swarm({ nodes: [weatherAgent], start: 'weather' }) + + const result = await swarm.invoke('What is the weather?') + expect(result.status).toBe(Status.INTERRUPTED) + expect(result.interrupts![0]!.source).toBe('tool') + + const finalResult = await resumeUntilDone( + (responses) => swarm.invoke(responses), + result, + () => 'cloudy' + ) + expect(finalResult.status).toBe(Status.COMPLETED) + }) +}) diff --git a/strands-ts/test/integ/multiagent/interrupt-session.test.node.ts b/strands-ts/test/integ/multiagent/interrupt-session.test.node.ts new file mode 100644 index 0000000000..ef2503ce9a --- /dev/null +++ b/strands-ts/test/integ/multiagent/interrupt-session.test.node.ts @@ -0,0 +1,69 @@ +/** + * Integration tests for multi-agent interrupt round-trip through a SessionManager: + * a fresh orchestrator instance picks up where the previous one paused, with state + * restored from `FileStorage`. + */ +import { describe, expect, it, beforeAll, afterAll } from 'vitest' +import { promises as fs } from 'fs' +import { join } from 'path' +import { tmpdir } from 'os' +import { v7 as uuidv7 } from 'uuid' +import { z } from 'zod' +import { Agent } from '$/sdk/agent/agent.js' +import { tool } from '$/sdk/tools/tool-factory.js' +import { TextBlock } from '$/sdk/types/messages.js' +import { Graph, Status } from '$/sdk/multiagent/index.js' +import { bedrock } from '../__fixtures__/model-providers.js' +import { makeSessionManager } from './_interrupt-helpers.js' + +const interruptingWeatherTool = tool({ + name: 'weather_tool', + description: 'Returns the current weather.', + inputSchema: z.object({}), + callback: async (_input, context) => + context!.interrupt({ name: 'weather_interrupt', reason: 'need weather' }) as string, +}) + +describe.skipIf(bedrock.skip)('Multi-agent interrupt session round-trip', () => { + const createModel = (maxTokens = 1024) => bedrock.createModel({ maxTokens }) + + let storageDir: string + beforeAll(async () => { + storageDir = join(tmpdir(), `strands-multiagent-interrupt-session-${uuidv7()}`) + await fs.mkdir(storageDir, { recursive: true }) + }) + afterAll(async () => { + await fs.rm(storageDir, { recursive: true, force: true }) + }) + + it('Graph: tool-interrupt persists and resumes with fresh orchestrator', async () => { + const sessionId = `graph-tool-${uuidv7()}` + const buildGraph = (): Graph => { + const agent = new Agent({ + model: createModel(), + printer: false, + id: 'weather', + tools: [interruptingWeatherTool], + systemPrompt: 'Use the weather tool then answer.', + }) + return new Graph({ + nodes: [agent], + edges: [], + sessionManager: makeSessionManager(sessionId, storageDir), + }) + } + + // Pass a ContentBlock[] so the invocation input round-trips through + // FileStorage JSON as block data and rehydrates into a valid agent message + // when the node runs on resume. + const firstResult = await buildGraph().invoke([new TextBlock('What is the weather?')]) + expect(firstResult.status).toBe(Status.INTERRUPTED) + expect(firstResult.interrupts![0]!.source).toBe('tool') + + const interrupt = firstResult.interrupts![0]! + const finalResult = await buildGraph().invoke([ + { interruptResponse: { interruptId: interrupt.id, response: 'cloudy' } }, + ]) + expect(finalResult.status).toBe(Status.COMPLETED) + }) +}) diff --git a/strands-ts/test/integ/multiagent/session-manager.test.node.ts b/strands-ts/test/integ/multiagent/session-manager.test.node.ts new file mode 100644 index 0000000000..a8e0a5c58e --- /dev/null +++ b/strands-ts/test/integ/multiagent/session-manager.test.node.ts @@ -0,0 +1,338 @@ +/** + * Integration tests for multi-agent session management (Swarm and Graph resume). + * Node-only: uses FileStorage which requires fs. + * + */ +import { describe, expect, it, beforeAll, afterAll } from 'vitest' +import { promises as fs } from 'fs' +import { join } from 'path' +import { tmpdir } from 'os' +import { v7 as uuidv7 } from 'uuid' +import { Agent } from '$/sdk/agent/agent.js' +import { + Swarm, + Status, + Graph, + BeforeNodeCallEvent, + BeforeMultiAgentInvocationEvent, + MultiAgentState, +} from '$/sdk/multiagent/index.js' +import type { EdgeDefinition } from '$/sdk/multiagent/index.js' +import { SessionManager } from '$/sdk/session/session-manager.js' +import { FileStorage } from '$/sdk/session/file-storage.js' +import { bedrock } from '../__fixtures__/model-providers.js' + +// ─── Helpers ───────────────────────────────────────────────────────────────── + +function makeSessionManager(sessionId: string, storageDir: string): SessionManager { + return new SessionManager({ sessionId, storage: { snapshot: new FileStorage(storageDir) } }) +} + +function createResearcherWriterNodes(createModel: () => ReturnType) { + return [ + new Agent({ + model: createModel(), + printer: false, + id: 'researcher', + description: 'Researches a topic then hands off to the writer.', + systemPrompt: + 'You are a researcher. Research the answer, then always hand off to the writer. Never produce a final response yourself.', + }), + new Agent({ + model: createModel(), + printer: false, + id: 'writer', + description: 'Writes a polished final answer in one sentence.', + systemPrompt: 'Write the final answer in one sentence. Do not hand off.', + }), + ] +} + +// ─── Swarm Resume ──────────────────────────────────────────────────────────── + +describe.skipIf(bedrock.skip)('Multi-Agent Session Management - Swarm', () => { + const createModel = (maxTokens = 1024) => bedrock.createModel({ maxTokens }) + let tempDir: string + + beforeAll(async () => { + tempDir = join(tmpdir(), `strands-multiagent-session-integ-${Date.now()}`) + await fs.mkdir(tempDir, { recursive: true }) + }) + + afterAll(async () => { + await fs.rm(tempDir, { recursive: true, force: true }) + }) + + it('resumes from the pending handoff target after maxSteps stops the swarm', async () => { + const sessionId = uuidv7() + const swarmId = 'resume-swarm' + + // First invocation: researcher hands off to writer, but maxSteps=1 stops before writer runs + const swarm1 = new Swarm({ + id: swarmId, + nodes: createResearcherWriterNodes(createModel), + start: 'researcher', + maxSteps: 1, + plugins: [makeSessionManager(sessionId, tempDir)], + }) + + await expect(swarm1.invoke('What is the tallest mountain?')).rejects.toThrow('swarm reached step limit') + + // Second invocation: new Swarm + SessionManager simulates process restart + const swarm2 = new Swarm({ + id: swarmId, + nodes: createResearcherWriterNodes(createModel), + start: 'researcher', + plugins: [makeSessionManager(sessionId, tempDir)], + }) + + const result = await swarm2.invoke('What is the tallest mountain?') + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.map((r) => r.nodeId)).toStrictEqual(['researcher', 'writer']) + + const text = result.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/Everest/i) + }) +}) + +// ─── Graph Resume ──────────────────────────────────────────────────────────── +describe.skipIf(bedrock.skip)('Multi-Agent Session Management - Graph', () => { + const createModel = (maxTokens = 1024) => bedrock.createModel({ maxTokens }) + let tempDir: string + + beforeAll(async () => { + tempDir = join(tmpdir(), `strands-graph-session-integ-${Date.now()}`) + await fs.mkdir(tempDir, { recursive: true }) + }) + + afterAll(async () => { + await fs.rm(tempDir, { recursive: true, force: true }) + }) + + /** + * Graph topology (parallel branches + sub-graph): + * + * researcher ──→ analyst ──→ reporter + * │ ↑ + * └──→ sub-graph ───────────┘ + * (drafter → reviewer) + * + * - `researcher` is the source node. + * - `analyst` and `sub-graph` run in parallel after researcher completes. + * - `reporter` waits for both analyst AND sub-graph (AND-join). + * - `sub-graph` is a nested Graph with two nodes: drafter → reviewer. + * + * First run: researcher and analyst complete, sub-graph is cancelled via hook. + * Resume: sub-graph executes (dep researcher=COMPLETED), then reporter fires + * (deps analyst=COMPLETED, sub-graph=COMPLETED). + * + * This tests: + * - sessionManager constructor arg (not plugins) + * - parallel execution (default maxConcurrency) + * - sub-graph (MultiAgentNode) resume + * - AND-join dependency resolution across resume boundary + * - cross-boundary data flow (reporter receives outputs from both runs) + */ + it('resumes graph with parallel branches and sub-graph across session boundary', async () => { + const sessionId = uuidv7() + const graphId = 'resume-subgraph' + + function makeAgent(id: string, prompt: string) { + return new Agent({ model: createModel(), printer: false, id, systemPrompt: prompt }) + } + + function createSubGraph() { + return new Graph({ + id: 'sub-graph', + nodes: [ + makeAgent('drafter', 'You are a drafter. Write a one-sentence draft about the topic.'), + makeAgent( + 'reviewer', + 'You are a reviewer. Improve the draft in one sentence. Mention "Everest" if the topic is about mountains.' + ), + ], + edges: [['drafter', 'reviewer']], + }) + } + + function createNodes() { + return [ + makeAgent('researcher', 'You are a researcher. State the topic of the question in one sentence.'), + makeAgent('analyst', 'You are an analyst. Add one key fact about the topic from the researcher.'), + createSubGraph(), + makeAgent( + 'reporter', + 'You are a reporter. Combine all inputs into a final two-sentence summary. Mention "Everest" if the topic is about mountains.' + ), + ] + } + + const edges: [string, string][] = [ + ['researcher', 'analyst'], + ['researcher', 'sub-graph'], + ['analyst', 'reporter'], + ['sub-graph', 'reporter'], + ] + + // ── Run 1: cancel sub-graph so only researcher + analyst complete ── + const graph1 = new Graph({ + id: graphId, + nodes: createNodes(), + edges, + sessionManager: makeSessionManager(sessionId, tempDir), + }) + + graph1.addHook(BeforeNodeCallEvent, (event) => { + if (event.nodeId === 'sub-graph') { + event.cancel = 'simulated crash' + } + }) + + const result1 = await graph1.invoke('What is the tallest mountain in the world?') + + const completedRun1 = result1.results.filter((r) => r.status === Status.COMPLETED).map((r) => r.nodeId) + expect(completedRun1).toContain('researcher') + expect(completedRun1).toContain('analyst') + expect(completedRun1).not.toContain('sub-graph') + expect(completedRun1).not.toContain('reporter') + + // Verify sessionManager property is accessible + expect(graph1.sessionManager).toBeDefined() + + // ── Run 2: fresh Graph + SessionManager, no cancel hook ── + const graph2 = new Graph({ + id: graphId, + nodes: createNodes(), + edges, + sessionManager: makeSessionManager(sessionId, tempDir), + }) + + const result2 = await graph2.invoke('What is the tallest mountain in the world?') + + const completedRun2 = result2.results.filter((r) => r.status === Status.COMPLETED).map((r) => r.nodeId) + + // Sub-graph and reporter should now be completed + expect(completedRun2).toContain('sub-graph') + expect(completedRun2).toContain('reporter') + + // Researcher and analyst should not be re-executed (exactly one COMPLETED each) + expect(completedRun2.filter((id) => id === 'researcher')).toHaveLength(1) + expect(completedRun2.filter((id) => id === 'analyst')).toHaveLength(1) + + // All completed nodes produced content + for (const nodeResult of result2.results.filter((r) => r.status === Status.COMPLETED)) { + expect(nodeResult.content.length).toBeGreaterThan(0) + } + + // Reporter is the terminus — verify it received data from both branches + // (analyst from run 1, sub-graph from run 2) by checking for topic-relevant content + const reporterText = result2.results + .filter((r) => r.nodeId === 'reporter' && r.status === Status.COMPLETED) + .flatMap((r) => r.content) + .find((b) => b.type === 'textBlock')?.text + expect(reporterText).toBeTruthy() + expect(reporterText).toMatch(/Everest|mountain|tallest/i) + }) + + /** + * Graph topology with conditional edge: + * + * researcher ──→ writer (conditional: only if app state has 'approved' flag) + * │ ↑ + * └──→ analyst ──┘ (unconditional) + * + * - `researcher` and `analyst` are sources (no incoming edges). + * - `writer` has an AND-join: needs both researcher and analyst COMPLETED, + * AND the researcher→writer conditional edge handler to return true. + * + * Run 1: researcher and analyst both complete normally. But the conditional + * edge handler checks `state.app.get('approved')` which is not set, so + * _findReady evaluates the handler → false → writer is blocked. + * All deps are COMPLETED but the handler rejects the transition. + * + * Run 2 (resume): state is restored (researcher=COMPLETED, analyst=COMPLETED, + * writer=PENDING). A BeforeMultiAgentInvocationEvent hook sets approved=true. + * _findResumeTargets evaluates the handler via _allDependenciesSatisfied + * → true → writer is ready and executes. + * + * This directly tests that _findResumeTargets evaluates edge handlers, + * not just source node statuses. + */ + it('resumes with conditional edge handlers evaluated correctly', async () => { + const sessionId = uuidv7() + const graphId = 'resume-conditional' + + function makeAgent(id: string, prompt: string) { + return new Agent({ model: createModel(), printer: false, id, systemPrompt: prompt }) + } + + function createNodes() { + return [ + makeAgent('researcher', 'You are a researcher. State one fact about the topic.'), + makeAgent('analyst', 'You are an analyst. Add one supporting detail about the topic.'), + makeAgent( + 'writer', + 'You are a writer. Write a polished one-sentence answer. Mention "Everest" if the topic is about mountains.' + ), + ] + } + + const edges: EdgeDefinition[] = [ + { + source: 'researcher', + target: 'writer', + handler: (state: MultiAgentState) => state.app.get('approved') === true, + }, + ['analyst', 'writer'], + ] + + // ── Run 1: no approval flag → writer blocked by handler despite all deps COMPLETED ── + const graph1 = new Graph({ + id: graphId, + nodes: createNodes(), + edges, + sessionManager: makeSessionManager(sessionId, tempDir), + }) + + const result1 = await graph1.invoke('What is the tallest mountain?') + + const completedRun1 = result1.results.filter((r) => r.status === Status.COMPLETED).map((r) => r.nodeId) + expect(completedRun1).toContain('researcher') + expect(completedRun1).toContain('analyst') + // Writer should NOT have run — both deps are COMPLETED but the handler returned false + expect(completedRun1).not.toContain('writer') + + // ── Run 2: set approval flag before resume so handler passes ── + const graph2 = new Graph({ + id: graphId, + nodes: createNodes(), + edges, + sessionManager: makeSessionManager(sessionId, tempDir), + }) + + // Initialize first so the session manager's restore hook is registered, + // then add our hook — hooks run in registration order, so restore happens + // before we set the flag. + await graph2.initialize() + graph2.addHook(BeforeMultiAgentInvocationEvent, (event) => { + event.state.app.set('approved', true) + }) + + const result2 = await graph2.invoke('What is the tallest mountain?') + + const completedRun2 = result2.results.filter((r) => r.status === Status.COMPLETED).map((r) => r.nodeId) + expect(completedRun2).toContain('writer') + + // Researcher and analyst should not be re-executed + expect(completedRun2.filter((id) => id === 'researcher')).toHaveLength(1) + expect(completedRun2.filter((id) => id === 'analyst')).toHaveLength(1) + + const writerText = result2.results + .filter((r) => r.nodeId === 'writer' && r.status === Status.COMPLETED) + .flatMap((r) => r.content) + .find((b) => b.type === 'textBlock')?.text + expect(writerText).toBeTruthy() + expect(writerText).toMatch(/Everest|mountain|tallest/i) + }) +}) diff --git a/strands-ts/test/integ/multiagent/swarm.test.ts b/strands-ts/test/integ/multiagent/swarm.test.ts new file mode 100644 index 0000000000..7b7013733d --- /dev/null +++ b/strands-ts/test/integ/multiagent/swarm.test.ts @@ -0,0 +1,88 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '$/sdk/agent/agent.js' +import { Swarm, Status } from '$/sdk/multiagent/index.js' +import { collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { bedrock } from '../__fixtures__/model-providers.js' + +describe.skipIf(bedrock.skip)('Swarm', () => { + const createModel = (maxTokens = 1024) => bedrock.createModel({ maxTokens }) + + it('completes single-agent execution with lifecycle events', async () => { + const agent = new Agent({ + model: createModel(), + printer: false, + id: 'assistant', + description: 'Answers questions briefly.', + systemPrompt: 'Answer in one word only.', + }) + + const swarm = new Swarm({ + nodes: [agent], + start: 'assistant', + }) + + const { items, result } = await collectGenerator(swarm.stream('What is the capital of France?')) + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results).toHaveLength(1) + expect(result.results[0]!.nodeId).toBe('assistant') + expect(result.duration).toBeGreaterThan(0) + + const text = result.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/Paris/i) + + // Verify lifecycle events + const eventTypes = items.map((e) => e.type) + expect(eventTypes[0]).toBe('beforeMultiAgentInvocationEvent') + expect(eventTypes).toContain('beforeNodeCallEvent') + expect(eventTypes).toContain('nodeStreamUpdateEvent') + expect(eventTypes).toContain('nodeResultEvent') + expect(eventTypes).toContain('afterNodeCallEvent') + expect(eventTypes).toContain('afterMultiAgentInvocationEvent') + expect(eventTypes).toContain('multiAgentResultEvent') + }) + + it('hands off between agents with handoff event', async () => { + const researcher = new Agent({ + model: createModel(), + printer: false, + id: 'researcher', + description: 'Researches a topic then hands off to the writer.', + systemPrompt: + 'You are a researcher. Look up the answer, then always hand off to the writer agent. Never produce a final response yourself.', + }) + + const writer = new Agent({ + model: createModel(), + printer: false, + id: 'writer', + description: 'Writes a final one-sentence answer.', + systemPrompt: 'Write the final answer in one sentence. Do not hand off to another agent.', + }) + + const swarm = new Swarm({ + nodes: [researcher, writer], + start: 'researcher', + maxSteps: 4, + }) + + const { items, result } = await collectGenerator(swarm.stream('What is the largest ocean?')) + + expect(result.status).toBe(Status.COMPLETED) + expect(result.results.length).toBeGreaterThanOrEqual(2) + expect(result.results[0]!.nodeId).toBe('researcher') + expect(result.duration).toBeGreaterThan(0) + + const text = result.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/Pacific/i) + + // Verify handoff event + const handoff = items.find((e) => e.type === 'multiAgentHandoffEvent') + expect(handoff).toEqual( + expect.objectContaining({ + source: 'researcher', + targets: ['writer'], + }) + ) + }) +}) diff --git a/strands-ts/test/integ/session-manager.test.node.ts b/strands-ts/test/integ/session-manager.test.node.ts new file mode 100644 index 0000000000..522dcc0a1a --- /dev/null +++ b/strands-ts/test/integ/session-manager.test.node.ts @@ -0,0 +1,343 @@ +/** + * Integration tests for session management. + */ +import { describe, expect, it, beforeAll, afterAll } from 'vitest' +import { promises as fs } from 'fs' +import { join } from 'path' +import { tmpdir } from 'os' +import { inject } from 'vitest' +import { v7 as uuidv7 } from 'uuid' +import { Agent } from '$/sdk/agent/agent.js' +import { + S3Client, + CreateBucketCommand, + DeleteBucketCommand, + DeleteObjectsCommand, + ListObjectsV2Command, +} from '@aws-sdk/client-s3' +import { STSClient, GetCallerIdentityCommand } from '@aws-sdk/client-sts' +import { SessionManager } from '$/sdk/session/session-manager.js' +import { FileStorage } from '$/sdk/session/file-storage.js' +import { S3Storage } from '$/sdk/session/s3-storage.js' +import { bedrock, openaiResponses } from './__fixtures__/model-providers.js' + +// ─── Helpers ───────────────────────────────────────────────────────────────── + +const AWS_REGION = process.env.AWS_REGION ?? 'us-east-1' + +async function getBucketName(credentials: any): Promise { + const sts = new STSClient({ region: AWS_REGION, credentials }) + const { Account } = await sts.send(new GetCallerIdentityCommand({})) + const suffix = Math.random().toString(16).slice(2, 8) + return `test-strands-session-${Account}-${AWS_REGION}-${suffix}` +} + +function makeFileManager(sessionId: string, storageDir: string): SessionManager { + return new SessionManager({ sessionId, storage: { snapshot: new FileStorage(storageDir) } }) +} + +function makeS3Manager(sessionId: string, bucket: string, credentials: any): SessionManager { + return new SessionManager({ + sessionId, + storage: { snapshot: new S3Storage({ bucket, s3Client: new S3Client({ region: AWS_REGION, credentials }) }) }, + }) +} + +async function getPersistedMessageCount(manager: SessionManager): Promise { + const snap = await (manager as any)._storage.snapshot.loadSnapshot({ + location: (manager as any)._location({ id: 'agent' }), + }) + return (snap?.data?.messages as unknown[])?.length ?? 0 +} + +// ─── File Storage Tests ─────────────────────────────────────────────────────── + +describe.skipIf(bedrock.skip)('Session Management - FileStorage', () => { + let tempDir: string + + beforeAll(async () => { + tempDir = join(tmpdir(), `strands-session-integ-${Date.now()}`) + await fs.mkdir(tempDir, { recursive: true }) + }) + + afterAll(async () => { + await fs.rm(tempDir, { recursive: true, force: true }) + }) + + it('persists and restores agent messages across sessions', async () => { + const sessionId = uuidv7() + const model = bedrock.createModel() + + const manager1 = makeFileManager(sessionId, tempDir) + const agent1 = new Agent({ model, sessionManager: manager1, printer: false }) + await agent1.invoke('Hello!') + expect(agent1.messages).toHaveLength(2) + expect(await getPersistedMessageCount(manager1)).toBe(2) + + const manager2 = makeFileManager(sessionId, tempDir) + const agent2 = new Agent({ model, sessionManager: manager2, printer: false }) + await agent2.initialize() + expect(agent2.messages).toHaveLength(2) + + await agent2.invoke('Hello again!') + expect(agent2.messages).toHaveLength(4) + expect(await getPersistedMessageCount(manager2)).toBe(4) + }) + + it('preserves conversation context across sessions', async () => { + const sessionId = uuidv7() + const model = bedrock.createModel() + + const manager1 = makeFileManager(sessionId, tempDir) + const agent1 = new Agent({ model, sessionManager: manager1, printer: false }) + await agent1.invoke('My name is Alice') + await agent1.invoke('What is my name?') + expect(agent1.messages).toHaveLength(4) + + const manager2 = makeFileManager(sessionId, tempDir) + const agent2 = new Agent({ model, sessionManager: manager2, printer: false }) + await agent2.initialize() + expect(agent2.messages).toHaveLength(4) + + const result = await agent2.invoke('Repeat my name') + const text = result.lastMessage.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/Alice/i) + }) + + it('deleteSession removes all session data', async () => { + const sessionId = uuidv7() + const model = bedrock.createModel() + const manager = makeFileManager(sessionId, tempDir) + const agent = new Agent({ model, sessionManager: manager, printer: false }) + await agent.invoke('Hello!') + expect(await getPersistedMessageCount(manager)).toBe(2) + + await manager.deleteSession() + + const sessionDir = join(tempDir, sessionId) + await expect(fs.access(sessionDir)).rejects.toThrow() + }) + + it('creates immutable snapshots, verifies storage layout, and restores from specific snapshot', async () => { + const sessionId = uuidv7() + const model = bedrock.createModel() + const storage = new FileStorage(tempDir) + + const manager1 = new SessionManager({ sessionId, storage: { snapshot: storage }, snapshotTrigger: () => true }) + const agent1 = new Agent({ model, sessionManager: manager1, printer: false }) + await agent1.invoke('First message') // snapshot 1: 2 messages + await agent1.invoke('Second message') // snapshot 2: 4 messages + expect(agent1.messages).toHaveLength(4) + + // Verify storage layout + const base = join(tempDir, sessionId, 'scopes', 'agent', 'agent', 'snapshots') + await expect(fs.access(join(base, 'snapshot_latest.json'))).resolves.toBeUndefined() + const files = await fs.readdir(join(base, 'immutable_history')) + expect(files).toHaveLength(2) + expect(files.every((f) => /^snapshot_[\w-]+\.json$/.test(f))).toBe(true) + + // Restore from snapshot 1 — should only have 2 messages + const snapshotIds = await storage.listSnapshotIds({ location: { sessionId, scope: 'agent', scopeId: 'agent' } }) + expect(snapshotIds[0]).toBeDefined() + const sessionManager2 = new SessionManager({ + sessionId, + storage: { snapshot: storage }, + }) + const agent2 = new Agent({ + model, + sessionManager: sessionManager2, + printer: false, + }) + await agent2.initialize() + await sessionManager2.restoreSnapshot({ target: agent2, snapshotId: snapshotIds[0]! }) + expect(agent2.messages).toHaveLength(2) + }) +}) + +// ─── Stateful Model Tests ───────────────────────────────────────────────────── + +describe.skipIf(openaiResponses.skip)('Session Management - stateful model (OpenAI Responses)', () => { + let tempDir: string + + beforeAll(async () => { + tempDir = join(tmpdir(), `strands-session-stateful-integ-${Date.now()}`) + await fs.mkdir(tempDir, { recursive: true }) + }) + + afterAll(async () => { + await fs.rm(tempDir, { recursive: true, force: true }) + }) + + it('persists modelState.responseId and restores a usable stateful agent', async () => { + const sessionId = uuidv7() + const manager1 = makeFileManager(sessionId, tempDir) + const agent1 = new Agent({ + model: openaiResponses.createModel({ modelId: 'gpt-5.4-mini', stateful: true }), + sessionManager: manager1, + printer: false, + systemPrompt: 'Reply in one short sentence.', + }) + + await agent1.invoke('Hello.') + // Stateful invariant: server owns history, local messages stay empty. + expect(agent1.messages).toEqual([]) + const firstResponseId = agent1.modelState.get('responseId') + expect(firstResponseId).toEqual(expect.any(String)) + + // Persisted snapshot must reflect both: empty messages and the captured responseId. + const snap1 = await (manager1 as any)._storage.snapshot.loadSnapshot({ + location: (manager1 as any)._location({ id: 'agent' }), + }) + expect(snap1?.data?.messages).toEqual([]) + expect(snap1?.data?.modelState).toEqual({ responseId: firstResponseId }) + + // Reload into a fresh agent/manager pair backed by the same storage. + const manager2 = makeFileManager(sessionId, tempDir) + const agent2 = new Agent({ + model: openaiResponses.createModel({ modelId: 'gpt-5.4-mini', stateful: true }), + sessionManager: manager2, + printer: false, + systemPrompt: 'Reply in one short sentence.', + }) + await agent2.initialize() + + expect(agent2.messages).toEqual([]) + expect(agent2.modelState.get('responseId')).toBe(firstResponseId) + + // The restored agent must be able to continue the conversation. We only + // assert mechanical outcomes — no model-output string checks, so no flake surface. + const turn2 = await agent2.invoke('Say something brief.') + expect(turn2.stopReason).toBe('endTurn') + expect(agent2.messages).toEqual([]) + expect(agent2.modelState.get('responseId')).toEqual(expect.any(String)) + expect(agent2.modelState.get('responseId')).not.toBe(firstResponseId) + }) +}) + +// ─── S3 Storage Tests ───────────────────────────────────────────────────────── + +describe.skipIf(bedrock.skip)('Session Management - S3Storage', () => { + let bucket: string + let credentials: any + let s3: S3Client + + beforeAll(async () => { + credentials = inject('provider-bedrock')?.credentials + bucket = await getBucketName(credentials) + s3 = new S3Client({ region: AWS_REGION, credentials }) + try { + await s3.send( + new CreateBucketCommand({ + Bucket: bucket, + ...(AWS_REGION !== 'us-east-1' && { CreateBucketConfiguration: { LocationConstraint: AWS_REGION as any } }), + }) + ) + } catch (e: any) { + if (e?.name !== 'BucketAlreadyOwnedByYou') throw e + } + }) + + afterAll(async () => { + // Delete all objects then the bucket + let token: string | undefined + do { + const list = await s3.send(new ListObjectsV2Command({ Bucket: bucket, ContinuationToken: token })) + const objects = list.Contents?.map((o) => ({ Key: o.Key! })) ?? [] + if (objects.length) await s3.send(new DeleteObjectsCommand({ Bucket: bucket, Delete: { Objects: objects } })) + token = list.NextContinuationToken + } while (token) + await s3.send(new DeleteBucketCommand({ Bucket: bucket })) + }) + + it('persists and restores agent messages across sessions', async () => { + const sessionId = uuidv7() + const model = bedrock.createModel() + + const manager1 = makeS3Manager(sessionId, bucket, credentials) + const agent1 = new Agent({ model, sessionManager: manager1, printer: false }) + await agent1.invoke('Hello!') + expect(agent1.messages).toHaveLength(2) + expect(await getPersistedMessageCount(manager1)).toBe(2) + + const manager2 = makeS3Manager(sessionId, bucket, credentials) + const agent2 = new Agent({ model, sessionManager: manager2, printer: false }) + await agent2.initialize() + expect(agent2.messages).toHaveLength(2) + + await agent2.invoke('Hello again!') + expect(agent2.messages).toHaveLength(4) + expect(await getPersistedMessageCount(manager2)).toBe(4) + }) + + it('preserves conversation context across sessions', async () => { + const sessionId = uuidv7() + const model = bedrock.createModel() + + const manager1 = makeS3Manager(sessionId, bucket, credentials) + const agent1 = new Agent({ model, sessionManager: manager1, printer: false }) + await agent1.invoke('My name is Bob') + await agent1.invoke('What is my name?') + expect(agent1.messages).toHaveLength(4) + + const manager2 = makeS3Manager(sessionId, bucket, credentials) + const agent2 = new Agent({ model, sessionManager: manager2, printer: false }) + await agent2.initialize() + expect(agent2.messages).toHaveLength(4) + + const result = await agent2.invoke('Repeat my name') + const text = result.lastMessage.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/Bob/i) + }) + + it('deleteSession removes all session data from S3', async () => { + const sessionId = uuidv7() + const model = bedrock.createModel() + const manager = makeS3Manager(sessionId, bucket, credentials) + const agent = new Agent({ model, sessionManager: manager, printer: false }) + await agent.invoke('Hello!') + expect(await getPersistedMessageCount(manager)).toBe(2) + + await manager.deleteSession() + + const list = await s3.send(new ListObjectsV2Command({ Bucket: bucket, Prefix: `${sessionId}/` })) + expect(list.Contents ?? []).toHaveLength(0) + }) + + it('creates immutable snapshots and supports time-travel restore', async () => { + const sessionId = uuidv7() + const model = bedrock.createModel() + + const manager1 = new SessionManager({ + sessionId, + storage: { snapshot: new S3Storage({ bucket, s3Client: new S3Client({ region: AWS_REGION, credentials }) }) }, + snapshotTrigger: ({ agentData }) => agentData.messages.length === 4, + saveLatestOn: 'invocation', + }) + const agent1 = new Agent({ model, sessionManager: manager1, printer: false }) + await agent1.invoke('What is 10 + 5?') // 2 messages — no snapshot + await agent1.invoke('What is 20 * 3?') // 4 messages — snapshot 1 + await agent1.invoke('What is 100 / 4?') // 6 messages — no snapshot + await agent1.invoke('What is 50 - 15?') // 8 messages — no snapshot + expect(agent1.messages).toHaveLength(8) + + // Verify UUID-based S3 key naming and restore from snapshot 1 (after turn 2) + const s3Storage = new S3Storage({ bucket, s3Client: new S3Client({ region: AWS_REGION, credentials }) }) + const snapshotIds = await s3Storage.listSnapshotIds({ location: { sessionId, scope: 'agent', scopeId: 'agent' } }) + expect(snapshotIds).toHaveLength(1) + expect(snapshotIds.every((id) => /^[\w-]{36}$/.test(id))).toBe(true) + expect(snapshotIds[0]).toBeDefined() + const s3Manager2 = new SessionManager({ + sessionId, + storage: { snapshot: s3Storage }, + saveLatestOn: 'trigger', + }) + const agent2 = new Agent({ model, sessionManager: s3Manager2, printer: false }) + await agent2.initialize() + await s3Manager2.restoreSnapshot({ target: agent2, snapshotId: snapshotIds[0]! }) + expect(agent2.messages).toHaveLength(4) + + const result = await agent2.invoke('What was my last question?') + const text = result.lastMessage.content.find((b) => b.type === 'textBlock') + expect(text?.text).toMatch(/20.*3|multiply|60/i) + }) +}) diff --git a/strands-ts/test/integ/skills/agent-skills.test.node.ts b/strands-ts/test/integ/skills/agent-skills.test.node.ts new file mode 100644 index 0000000000..cadb40147f --- /dev/null +++ b/strands-ts/test/integ/skills/agent-skills.test.node.ts @@ -0,0 +1,178 @@ +import { describe, expect, it, beforeEach, afterEach } from 'vitest' +import { Agent } from '$/sdk/index.js' +import { AgentSkills, Skill } from '$/sdk/vended-plugins/skills/index.js' +import { getMessageText } from '../__fixtures__/model-test-helpers.js' +import { bedrock } from '../__fixtures__/model-providers.js' +import { promises as fs } from 'fs' +import * as path from 'path' +import { tmpdir } from 'os' + +describe.skipIf(bedrock.skip)('AgentSkills Plugin Integration', () => { + const SECRET_CODEWORD = 'PAPAYA7' + const ALT_SECRET_CODEWORD = 'MANGO9' + + const summarizationSkill = new Skill({ + name: 'summarization', + description: 'Summarize text into concise bullet points', + instructions: `You are a summarization expert. When this skill is activated, summarize the provided text. +The secret codeword is: ${SECRET_CODEWORD}. Always include this codeword in your response.`, + }) + + const translationSkill = new Skill({ + name: 'translation', + description: 'Translate text between languages', + instructions: 'You are a translation expert. Translate the provided text accurately.', + }) + + const createModel = () => + bedrock.createModel({ + maxTokens: 1024, + }) + + describe('agent activates skill and uses instructions', () => { + it('activates a skill via prompt and includes the secret codeword', async () => { + const plugin = new AgentSkills({ + skills: [summarizationSkill, translationSkill], + }) + + const agent = new Agent({ + model: createModel(), + plugins: [plugin], + printer: false, + }) + + const result = await agent.invoke( + 'Activate the summarization skill and tell me the secret codeword from its instructions.' + ) + + const responseText = getMessageText(result.lastMessage) + + // Verify the model used the skills tool + const toolUseMessage = agent.messages.find((msg) => + msg.content.some((block) => block.type === 'toolUseBlock' && block.name === 'skills') + ) + expect(toolUseMessage).toBeDefined() + + // Verify the model found the secret codeword from the skill instructions + expect(responseText).toContain(SECRET_CODEWORD) + + // Verify the system prompt has skill metadata injected + const systemPrompt = agent.systemPrompt as string + expect(systemPrompt).toContain('') + expect(systemPrompt).toContain('summarization') + expect(systemPrompt).toContain('translation') + }) + }) + + describe('skill activation state persistence', () => { + it('tracks activated skills in agent appState', async () => { + const plugin = new AgentSkills({ + skills: [summarizationSkill, translationSkill], + }) + + const agent = new Agent({ + model: createModel(), + plugins: [plugin], + printer: false, + }) + + // Activate the first skill + await agent.invoke('Activate the summarization skill.') + let activated = plugin.getActivatedSkills(agent) + expect(activated).toContain('summarization') + + // Activate the second skill + await agent.invoke('Now activate the translation skill.') + activated = plugin.getActivatedSkills(agent) + expect(activated).toContain('summarization') + expect(activated).toContain('translation') + }) + }) + + describe('load skills from filesystem', () => { + let testDir: string + + beforeEach(async () => { + testDir = path.join(tmpdir(), `skills-integ-test-${Date.now()}-${Math.random().toString(36).slice(2)}`) + await fs.mkdir(testDir, { recursive: true }) + }) + + afterEach(async () => { + await fs.rm(testDir, { recursive: true, force: true }) + }) + + it('loads a skill from disk and activates it', async () => { + // Create a skill directory with SKILL.md + const skillDir = path.join(testDir, 'code-review') + await fs.mkdir(skillDir, { recursive: true }) + await fs.writeFile( + path.join(skillDir, 'SKILL.md'), + `--- +name: code-review +description: Review code for bugs and improvements +--- +You are a code review expert. When reviewing code, look for bugs, security issues, and performance improvements. +The secret codeword for this skill is: ${ALT_SECRET_CODEWORD}.`, + 'utf-8' + ) + + const plugin = new AgentSkills({ + skills: [testDir], + }) + + // Verify the skill was loaded from the directory + const availableSkills = await plugin.getAvailableSkills() + expect(availableSkills).toHaveLength(1) + expect(availableSkills[0]!.name).toBe('code-review') + + const agent = new Agent({ + model: createModel(), + plugins: [plugin], + printer: false, + }) + + const result = await agent.invoke( + 'Activate the code-review skill and tell me the secret codeword from its instructions.' + ) + + const responseText = getMessageText(result.lastMessage) + expect(responseText).toContain(ALT_SECRET_CODEWORD) + }) + }) + + describe('system prompt marker replacement', () => { + it('replaces the skills block with updated content between invocations', async () => { + const plugin = new AgentSkills({ + skills: [summarizationSkill], + }) + + const agent = new Agent({ + model: createModel(), + plugins: [plugin], + printer: false, + systemPrompt: 'You are a helpful assistant.', + }) + + // First invocation — only summarization is available + await agent.invoke('Hello.') + + const promptAfterFirst = agent.systemPrompt as string + expect((promptAfterFirst.match(//g) ?? []).length).toBe(1) + expect(promptAfterFirst).toContain('You are a helpful assistant.') + expect(promptAfterFirst).toContain('summarization') + expect(promptAfterFirst).not.toContain('translation') + + // Swap the skill set between invocations + plugin.setAvailableSkills([translationSkill]) + + // Second invocation — only translation should appear, summarization gone + await agent.invoke('Hello again.') + + const promptAfterSecond = agent.systemPrompt as string + expect((promptAfterSecond.match(//g) ?? []).length).toBe(1) + expect(promptAfterSecond).toContain('You are a helpful assistant.') + expect(promptAfterSecond).toContain('translation') + expect(promptAfterSecond).not.toContain('summarization') + }) + }) +}) diff --git a/strands-ts/test/integ/telemetry.test.node.ts b/strands-ts/test/integ/telemetry.test.node.ts new file mode 100644 index 0000000000..46aedd062c --- /dev/null +++ b/strands-ts/test/integ/telemetry.test.node.ts @@ -0,0 +1,947 @@ +import { describe, it, expect, beforeAll, beforeEach, afterAll } from 'vitest' +import { Agent, tool } from '@strands-agents/sdk' +import { getTracer, getMeter } from '@strands-agents/sdk/telemetry' +import { NodeTracerProvider } from '@opentelemetry/sdk-trace-node' +import { InMemorySpanExporter, SimpleSpanProcessor } from '@opentelemetry/sdk-trace-base' +import type { ReadableSpan } from '@opentelemetry/sdk-trace-base' +import { SpanStatusCode, trace, context, metrics as otelMetrics } from '@opentelemetry/api' +import { + MeterProvider, + InMemoryMetricExporter, + PeriodicExportingMetricReader, + AggregationTemporality, +} from '@opentelemetry/sdk-metrics' +import { z } from 'zod' +import { MockMessageModel } from '$/sdk/__fixtures__/mock-message-model.js' +import { TestModelProvider, collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { findMetricValue } from '$/sdk/__fixtures__/metrics-helpers.js' + +const AGENT_SPAN_PREFIX = 'invoke_agent' +const CYCLE_SPAN_NAME = 'execute_agent_loop_cycle' +const MODEL_SPAN_NAME = 'chat' +const TOOL_SPAN_PREFIX = 'execute_tool' + +// Shared provider and exporter — registered once, reset between tests +let provider: NodeTracerProvider +let exporter: InMemorySpanExporter + +function getSpans(): ReadableSpan[] { + return [...exporter.getFinishedSpans()].sort( + // Compare OTel HrTime [seconds, nanoseconds] — seconds first, then nanoseconds as tiebreaker + (a, b) => a.startTime[0] - b.startTime[0] || a.startTime[1] - b.startTime[1] + ) +} + +function findSpans(spans: ReadableSpan[], prefix: string): ReadableSpan[] { + return spans.filter((s) => s.name.startsWith(prefix)) +} + +function assertParentChild(parent: ReadableSpan, child: ReadableSpan): void { + expect(child.spanContext().traceId).toBe(parent.spanContext().traceId) + expect(child.parentSpanContext?.spanId).toBe(parent.spanContext().spanId) +} + +function attr(span: ReadableSpan, key: string): unknown { + return span.attributes[key] +} + +const calculatorTool = tool({ + name: 'calculator', + description: 'Add two numbers', + inputSchema: z.object({ a: z.number(), b: z.number() }), + callback: ({ a, b }) => `${a + b}`, +}) + +const failingTool = tool({ + name: 'failing_tool', + description: 'Always fails', + inputSchema: z.object({}), + callback: () => { + throw new Error('tool exploded') + }, +}) + +describe.sequential('Telemetry Integration', () => { + beforeAll(() => { + exporter = new InMemorySpanExporter() + provider = new NodeTracerProvider({ spanProcessors: [new SimpleSpanProcessor(exporter)] }) + provider.register() + }) + + beforeEach(() => { + exporter.reset() + }) + + afterAll(async () => { + await provider.forceFlush() + await provider.shutdown() + }) + + /** + * Flush and return all spans captured during the current test. + */ + async function flush(): Promise { + await provider.forceFlush() + return getSpans() + } + + describe('span hierarchy', () => { + it('creates agent → cycle → model spans for a simple invocation', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello back' }) + const agent = new Agent({ model, printer: false, name: 'hierarchy-agent' }) + + await agent.invoke('Hi') + + const spans = await flush() + const agentSpans = findSpans(spans, AGENT_SPAN_PREFIX) + const cycleSpans = findSpans(spans, CYCLE_SPAN_NAME) + const modelSpans = findSpans(spans, MODEL_SPAN_NAME) + + expect(agentSpans).toHaveLength(1) + expect(cycleSpans).toHaveLength(1) + expect(modelSpans).toHaveLength(1) + + // Verify span names + expect(agentSpans[0]!.name).toBe('invoke_agent hierarchy-agent') + expect(cycleSpans[0]!.name).toBe('execute_agent_loop_cycle') + expect(modelSpans[0]!.name).toBe('chat') + + assertParentChild(agentSpans[0]!, cycleSpans[0]!) + assertParentChild(cycleSpans[0]!, modelSpans[0]!) + }) + + it('creates tool spans nested under cycle spans', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calculator', toolUseId: 'tool-1', input: { a: 1, b: 2 } }) + .addTurn({ type: 'textBlock', text: 'The answer is 3' }) + + const agent = new Agent({ model, printer: false, name: 'tool-agent', tools: [calculatorTool] }) + + await agent.invoke('Add 1 and 2') + + const spans = await flush() + const agentSpans = findSpans(spans, AGENT_SPAN_PREFIX) + const cycleSpans = findSpans(spans, CYCLE_SPAN_NAME) + const modelSpans = findSpans(spans, MODEL_SPAN_NAME) + const toolSpans = findSpans(spans, TOOL_SPAN_PREFIX) + + // Verify exact span counts and names + expect(agentSpans.map((s) => s.name)).toStrictEqual(['invoke_agent tool-agent']) + expect(cycleSpans).toHaveLength(2) + expect(modelSpans).toHaveLength(2) + expect(toolSpans.map((s) => s.name)).toStrictEqual(['execute_tool calculator']) + + // Both cycles parent to agent + assertParentChild(agentSpans[0]!, cycleSpans[0]!) + assertParentChild(agentSpans[0]!, cycleSpans[1]!) + + // Tool span parents to first cycle + assertParentChild(cycleSpans[0]!, toolSpans[0]!) + + // All spans share the same trace ID + const traceId = agentSpans[0]!.spanContext().traceId + for (const span of spans) { + expect(span.spanContext().traceId).toBe(traceId) + } + }) + + it('creates correct hierarchy for multi-tool invocation in a single cycle', async () => { + const echoTool = tool({ + name: 'echo', + description: 'Echo input', + inputSchema: z.object({ text: z.string() }), + callback: ({ text }) => text, + }) + + const model = new MockMessageModel() + .addTurn([ + { type: 'toolUseBlock', name: 'calculator', toolUseId: 'tool-1', input: { a: 1, b: 2 } }, + { type: 'toolUseBlock', name: 'echo', toolUseId: 'tool-2', input: { text: 'hello' } }, + ]) + .addTurn({ type: 'textBlock', text: 'Done' }) + + const agent = new Agent({ model, printer: false, name: 'multi-tool-agent', tools: [calculatorTool, echoTool] }) + + await agent.invoke('Do both') + + const spans = await flush() + const toolSpans = findSpans(spans, TOOL_SPAN_PREFIX) + const cycleSpans = findSpans(spans, CYCLE_SPAN_NAME) + + expect(toolSpans.map((s) => s.name)).toStrictEqual(['execute_tool calculator', 'execute_tool echo']) + assertParentChild(cycleSpans[0]!, toolSpans[0]!) + assertParentChild(cycleSpans[0]!, toolSpans[1]!) + }) + }) + + describe('span attributes', () => { + it('sets agent span attributes correctly', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ + model, + printer: false, + name: 'attr-agent', + systemPrompt: 'You are helpful', + tools: [calculatorTool], + traceAttributes: { 'app.custom': 'value' }, + }) + + await agent.invoke('Hello') + + const spans = await flush() + const agentSpan = findSpans(spans, AGENT_SPAN_PREFIX)[0]! + + expect(attr(agentSpan, 'gen_ai.operation.name')).toBe('invoke_agent') + expect(attr(agentSpan, 'gen_ai.agent.name')).toBe('attr-agent') + expect(attr(agentSpan, 'gen_ai.request.model')).toBe('test-model') + expect(attr(agentSpan, 'app.custom')).toBe('value') + expect(attr(agentSpan, 'system_prompt')).toBe('"You are helpful"') + + const toolNames = attr(agentSpan, 'gen_ai.agent.tools') as string + expect(JSON.parse(toolNames)).toStrictEqual(['calculator']) + }) + + it('sets model span attributes correctly', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model, printer: false, name: 'model-attr-agent' }) + + await agent.invoke('Hello') + + const spans = await flush() + const modelSpan = findSpans(spans, MODEL_SPAN_NAME)[0]! + + expect(attr(modelSpan, 'gen_ai.operation.name')).toBe('chat') + expect(attr(modelSpan, 'gen_ai.request.model')).toBe('test-model') + + const choiceEvent = modelSpan.events.find((e) => e.name === 'gen_ai.choice') + expect(choiceEvent).toBeDefined() + expect(JSON.parse(choiceEvent!.attributes!['message'] as string)).toStrictEqual([{ text: 'Response' }]) + }) + + it('sets tool span attributes correctly', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calculator', toolUseId: 'tool-42', input: { a: 5, b: 3 } }) + .addTurn({ type: 'textBlock', text: '8' }) + + const agent = new Agent({ model, printer: false, name: 'tool-attr-agent', tools: [calculatorTool] }) + + await agent.invoke('Add 5 and 3') + + const spans = await flush() + const toolSpan = findSpans(spans, TOOL_SPAN_PREFIX)[0]! + + expect(attr(toolSpan, 'gen_ai.operation.name')).toBe('execute_tool') + expect(attr(toolSpan, 'gen_ai.tool.name')).toBe('calculator') + expect(attr(toolSpan, 'gen_ai.tool.call.id')).toBe('tool-42') + + const choiceEvent = toolSpan.events.find((e) => e.name === 'gen_ai.choice') + expect(choiceEvent).toBeDefined() + expect(choiceEvent!.attributes!['id']).toBe('tool-42') + expect(JSON.parse(choiceEvent!.attributes!['message'] as string)).toStrictEqual([{ text: '8' }]) + }) + + it('sets cycle span attributes correctly', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Done' }) + const agent = new Agent({ model, printer: false, name: 'cycle-attr-agent' }) + + await agent.invoke('Hello') + + const spans = await flush() + const cycleSpan = findSpans(spans, CYCLE_SPAN_NAME)[0]! + + expect(attr(cycleSpan, 'agent_loop.cycle_id')).toBe('cycle-1') + }) + }) + + describe('custom trace attributes', () => { + it('merges constructor-level trace attributes onto agent span', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ + model, + printer: false, + name: 'custom-attr-agent', + traceAttributes: { 'app.module': 'weather', 'app.version': '1.0.0' }, + }) + + await agent.invoke('Hello') + + const spans = await flush() + const agentSpan = findSpans(spans, AGENT_SPAN_PREFIX)[0]! + + expect(attr(agentSpan, 'app.module')).toBe('weather') + expect(attr(agentSpan, 'app.version')).toBe('1.0.0') + }) + + it('traceAttributes override SDK-computed attributes for colliding keys', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ + model, + printer: false, + name: 'override-agent', + traceAttributes: { 'gen_ai.agent.name': 'custom-name', 'gen_ai.request.model': 'custom-model' }, + }) + + await agent.invoke('Hello') + + const spans = await flush() + const agentSpan = findSpans(spans, AGENT_SPAN_PREFIX)[0]! + + expect(attr(agentSpan, 'gen_ai.agent.name')).toBe('custom-name') + expect(attr(agentSpan, 'gen_ai.request.model')).toBe('custom-model') + }) + }) + + describe('stop reason propagation', () => { + it('records stop reason in agent span response event', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Final answer' }) + const agent = new Agent({ model, printer: false, name: 'stop-reason-agent' }) + + await agent.invoke('Hello') + + const spans = await flush() + const agentSpan = findSpans(spans, AGENT_SPAN_PREFIX)[0]! + + const choiceEvent = agentSpan.events.find((e) => e.name === 'gen_ai.choice') + expect(choiceEvent).toBeDefined() + expect(choiceEvent!.attributes!['finish_reason']).toBe('endTurn') + expect(choiceEvent!.attributes!['message']).toBe('Final answer') + }) + + it('records stop reason in model span output event', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response' }) + const agent = new Agent({ model, printer: false, name: 'model-stop-agent' }) + + await agent.invoke('Hello') + + const spans = await flush() + const modelSpan = findSpans(spans, MODEL_SPAN_NAME)[0]! + + const choiceEvent = modelSpan.events.find((e) => e.name === 'gen_ai.choice') + expect(choiceEvent).toBeDefined() + expect(choiceEvent!.attributes!['finish_reason']).toBe('endTurn') + + const message = JSON.parse(choiceEvent!.attributes!['message'] as string) + expect(message).toStrictEqual([{ text: 'Response' }]) + }) + }) + + describe('error handling', () => { + it('records error status on agent span when model throws', async () => { + const model = new MockMessageModel().addTurn(new Error('Model failed')) + const agent = new Agent({ model, printer: false, name: 'error-agent' }) + + await expect(agent.invoke('Hello')).rejects.toThrow() + + const spans = await flush() + const agentSpan = findSpans(spans, AGENT_SPAN_PREFIX)[0]! + + expect(agentSpan.status.code).toBe(SpanStatusCode.ERROR) + expect(agentSpan.status.message).toBe('Model failed') + }) + + it('records error status and exception event on model span when model throws', async () => { + const model = new MockMessageModel().addTurn(new Error('Model failed')) + const agent = new Agent({ model, printer: false, name: 'model-error-agent' }) + + await expect(agent.invoke('Hello')).rejects.toThrow() + + const spans = await flush() + const modelSpan = findSpans(spans, MODEL_SPAN_NAME)[0]! + + expect(modelSpan.status.code).toBe(SpanStatusCode.ERROR) + expect(modelSpan.status.message).toBe('Model failed') + + const exceptionEvent = modelSpan.events.find((e) => e.name === 'exception') + expect(exceptionEvent).toBeDefined() + expect(exceptionEvent!.attributes!['exception.message']).toBe('Model failed') + }) + + it('records error status and exception event on tool span when tool throws', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'failing_tool', toolUseId: 'tool-1', input: {} }) + .addTurn({ type: 'textBlock', text: 'Handled the error' }) + + const agent = new Agent({ model, printer: false, name: 'tool-error-agent', tools: [failingTool] }) + + await agent.invoke('Do something') + + const spans = await flush() + const toolSpan = findSpans(spans, TOOL_SPAN_PREFIX)[0]! + + expect(toolSpan.status.code).toBe(SpanStatusCode.ERROR) + expect(toolSpan.status.message).toBe('tool exploded') + + const exceptionEvent = toolSpan.events.find((e) => e.name === 'exception') + expect(exceptionEvent).toBeDefined() + expect(exceptionEvent!.attributes!['exception.message']).toBe('tool exploded') + }) + + it('records error on cycle span when model throws mid-loop', async () => { + const model = new MockMessageModel().addTurn(new Error('Cycle failure')) + const agent = new Agent({ model, printer: false, name: 'cycle-error-agent' }) + + await expect(agent.invoke('Hello')).rejects.toThrow() + + const spans = await flush() + const cycleSpan = findSpans(spans, CYCLE_SPAN_NAME)[0]! + + expect(cycleSpan.status.code).toBe(SpanStatusCode.ERROR) + expect(cycleSpan.status.message).toBe('Cycle failure') + }) + + it('sets OK status on all spans for successful invocations', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'All good' }) + const agent = new Agent({ model, printer: false, name: 'ok-agent' }) + + await agent.invoke('Hello') + + const spans = await flush() + for (const span of spans) { + expect(span.status.code).toBe(SpanStatusCode.OK) + } + }) + }) + + describe('multi-cycle agent loops', () => { + it('creates separate cycle spans for each loop iteration', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calculator', toolUseId: 'tool-1', input: { a: 1, b: 2 } }) + .addTurn({ type: 'toolUseBlock', name: 'calculator', toolUseId: 'tool-2', input: { a: 3, b: 4 } }) + .addTurn({ type: 'textBlock', text: 'All done' }) + + const agent = new Agent({ model, printer: false, name: 'multi-cycle-agent', tools: [calculatorTool] }) + + await agent.invoke('Do two calculations') + + const spans = await flush() + const agentSpans = findSpans(spans, AGENT_SPAN_PREFIX) + const cycleSpans = findSpans(spans, CYCLE_SPAN_NAME) + const modelSpans = findSpans(spans, MODEL_SPAN_NAME) + const toolSpans = findSpans(spans, TOOL_SPAN_PREFIX) + + expect(agentSpans.map((s) => s.name)).toStrictEqual(['invoke_agent multi-cycle-agent']) + expect(cycleSpans).toHaveLength(3) + expect(modelSpans).toHaveLength(3) + expect(toolSpans.map((s) => s.name)).toStrictEqual(['execute_tool calculator', 'execute_tool calculator']) + + expect(cycleSpans.map((s) => attr(s, 'agent_loop.cycle_id'))).toStrictEqual(['cycle-1', 'cycle-2', 'cycle-3']) + + for (const cycle of cycleSpans) { + assertParentChild(agentSpans[0]!, cycle) + } + }) + }) + + describe('streaming', () => { + it('creates the same span hierarchy when using stream()', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calculator', toolUseId: 'tool-1', input: { a: 2, b: 3 } }) + .addTurn({ type: 'textBlock', text: '5' }) + + const agent = new Agent({ model, printer: false, name: 'stream-agent', tools: [calculatorTool] }) + + await collectGenerator(agent.stream('Add 2 and 3')) + + const spans = await flush() + const agentSpans = findSpans(spans, AGENT_SPAN_PREFIX) + const cycleSpans = findSpans(spans, CYCLE_SPAN_NAME) + const modelSpans = findSpans(spans, MODEL_SPAN_NAME) + const toolSpans = findSpans(spans, TOOL_SPAN_PREFIX) + + expect(agentSpans.map((s) => s.name)).toStrictEqual(['invoke_agent stream-agent']) + expect(cycleSpans).toHaveLength(2) + expect(modelSpans).toHaveLength(2) + expect(toolSpans.map((s) => s.name)).toStrictEqual(['execute_tool calculator']) + + assertParentChild(agentSpans[0]!, cycleSpans[0]!) + assertParentChild(agentSpans[0]!, cycleSpans[1]!) + assertParentChild(cycleSpans[0]!, toolSpans[0]!) + assertParentChild(cycleSpans[0]!, modelSpans[0]!) + assertParentChild(cycleSpans[1]!, modelSpans[1]!) + + // All spans OK + for (const span of spans) { + expect(span.status.code).toBe(SpanStatusCode.OK) + } + + // Verify tool output content + const toolSpan = toolSpans[0]! + const toolChoiceEvent = toolSpan.events.find((e) => e.name === 'gen_ai.choice') + expect(toolChoiceEvent).toBeDefined() + expect(JSON.parse(toolChoiceEvent!.attributes!['message'] as string)).toStrictEqual([{ text: '5' }]) + }) + }) + + describe('span timing', () => { + it('sets ISO 8601 start and end time attributes on all spans', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Done' }) + const agent = new Agent({ model, printer: false, name: 'timing-agent' }) + + await agent.invoke('Hello') + + const spans = await flush() + const isoPattern = /^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/ + for (const span of spans) { + const startTime = attr(span, 'gen_ai.event.start_time') as string + const endTime = attr(span, 'gen_ai.event.end_time') as string + expect(startTime).toMatch(isoPattern) + expect(endTime).toMatch(isoPattern) + expect(new Date(startTime).getTime()).toBeLessThanOrEqual(new Date(endTime).getTime()) + } + }) + }) + + describe('span events', () => { + it('records user message and response choice events on agent span', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello there' }) + const agent = new Agent({ model, printer: false, name: 'events-agent' }) + + await agent.invoke('Hi') + + const spans = await flush() + const agentSpan = findSpans(spans, AGENT_SPAN_PREFIX)[0]! + + const userEvent = agentSpan.events.find((e) => e.name === 'gen_ai.user.message') + expect(userEvent).toBeDefined() + const userContent = JSON.parse(userEvent!.attributes!['content'] as string) + expect(userContent).toStrictEqual([{ text: 'Hi' }]) + + const choiceEvent = agentSpan.events.find((e) => e.name === 'gen_ai.choice') + expect(choiceEvent).toBeDefined() + expect(choiceEvent!.attributes!['message']).toBe('Hello there') + expect(choiceEvent!.attributes!['finish_reason']).toBe('endTurn') + }) + + it('records tool input and output events with correct data on tool span', async () => { + const model = new MockMessageModel() + .addTurn({ type: 'toolUseBlock', name: 'calculator', toolUseId: 'tool-1', input: { a: 10, b: 20 } }) + .addTurn({ type: 'textBlock', text: '30' }) + + const agent = new Agent({ model, printer: false, name: 'tool-events-agent', tools: [calculatorTool] }) + + await agent.invoke('Add 10 and 20') + + const spans = await flush() + const toolSpan = findSpans(spans, TOOL_SPAN_PREFIX)[0]! + + const toolInputEvent = toolSpan.events.find((e) => e.name === 'gen_ai.tool.message') + expect(toolInputEvent).toBeDefined() + expect(toolInputEvent!.attributes!['role']).toBe('tool') + expect(JSON.parse(toolInputEvent!.attributes!['content'] as string)).toStrictEqual({ a: 10, b: 20 }) + expect(toolInputEvent!.attributes!['id']).toBe('tool-1') + + const toolOutputEvent = toolSpan.events.find((e) => e.name === 'gen_ai.choice') + expect(toolOutputEvent).toBeDefined() + expect(toolOutputEvent!.attributes!['id']).toBe('tool-1') + expect(JSON.parse(toolOutputEvent!.attributes!['message'] as string)).toStrictEqual([{ text: '30' }]) + }) + }) + + describe('token usage accumulation', () => { + it('records accumulated usage on agent span across multiple cycles', async () => { + let callCount = 0 + const model = new TestModelProvider(() => { + callCount++ + return (async function* () { + yield { type: 'modelMessageStartEvent' as const, role: 'assistant' as const } + + if (callCount === 1) { + // First call: tool use + yield { + type: 'modelContentBlockStartEvent' as const, + start: { type: 'toolUseStart' as const, name: 'calculator', toolUseId: 'tool-1' }, + } + yield { + type: 'modelContentBlockDeltaEvent' as const, + delta: { type: 'toolUseInputDelta' as const, input: '{"a":1,"b":2}' }, + } + yield { type: 'modelContentBlockStopEvent' as const } + yield { type: 'modelMessageStopEvent' as const, stopReason: 'toolUse' as const } + yield { + type: 'modelMetadataEvent' as const, + usage: { inputTokens: 100, outputTokens: 50, totalTokens: 150 }, + } + } else { + // Second call: text response + yield { type: 'modelContentBlockStartEvent' as const } + yield { + type: 'modelContentBlockDeltaEvent' as const, + delta: { type: 'textDelta' as const, text: 'The answer is 3' }, + } + yield { type: 'modelContentBlockStopEvent' as const } + yield { type: 'modelMessageStopEvent' as const, stopReason: 'endTurn' as const } + yield { + type: 'modelMetadataEvent' as const, + usage: { inputTokens: 200, outputTokens: 75, totalTokens: 275 }, + } + } + })() + }) + + const agent = new Agent({ model, printer: false, name: 'usage-agent', tools: [calculatorTool] }) + + await agent.invoke('Add 1 and 2') + + const spans = await flush() + const agentSpan = findSpans(spans, AGENT_SPAN_PREFIX)[0]! + + // Accumulated: 100+200=300 input, 50+75=125 output, 150+275=425 total + expect(attr(agentSpan, 'gen_ai.usage.input_tokens')).toBe(300) + expect(attr(agentSpan, 'gen_ai.usage.output_tokens')).toBe(125) + expect(attr(agentSpan, 'gen_ai.usage.total_tokens')).toBe(425) + // Legacy attribute names + expect(attr(agentSpan, 'gen_ai.usage.prompt_tokens')).toBe(300) + expect(attr(agentSpan, 'gen_ai.usage.completion_tokens')).toBe(125) + }) + + it('records per-call usage on individual model spans', async () => { + let callCount = 0 + const model = new TestModelProvider(() => { + callCount++ + return (async function* () { + yield { type: 'modelMessageStartEvent' as const, role: 'assistant' as const } + yield { type: 'modelContentBlockStartEvent' as const } + yield { + type: 'modelContentBlockDeltaEvent' as const, + delta: { type: 'textDelta' as const, text: `Response ${callCount}` }, + } + yield { type: 'modelContentBlockStopEvent' as const } + yield { type: 'modelMessageStopEvent' as const, stopReason: 'endTurn' as const } + yield { + type: 'modelMetadataEvent' as const, + usage: { inputTokens: callCount * 10, outputTokens: callCount * 5, totalTokens: callCount * 15 }, + } + })() + }) + + const agent = new Agent({ model, printer: false, name: 'model-usage-agent' }) + + await agent.invoke('Hello') + + const spans = await flush() + const modelSpan = findSpans(spans, MODEL_SPAN_NAME)[0]! + + expect(attr(modelSpan, 'gen_ai.usage.input_tokens')).toBe(10) + expect(attr(modelSpan, 'gen_ai.usage.output_tokens')).toBe(5) + expect(attr(modelSpan, 'gen_ai.usage.total_tokens')).toBe(15) + }) + }) + + describe('concurrent agents', () => { + it('creates isolated traces for concurrent agent invocations', async () => { + const model1 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Agent 1 response' }) + const model2 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Agent 2 response' }) + + const agent1 = new Agent({ model: model1, printer: false, name: 'agent-1' }) + const agent2 = new Agent({ model: model2, printer: false, name: 'agent-2' }) + + await Promise.all([agent1.invoke('Hello 1'), agent2.invoke('Hello 2')]) + + const spans = await flush() + const agentSpans = findSpans(spans, AGENT_SPAN_PREFIX) + + expect(agentSpans).toHaveLength(2) + + const spanNames = agentSpans.map((s) => s.name).sort() + expect(spanNames).toStrictEqual(['invoke_agent agent-1', 'invoke_agent agent-2']) + + // Each agent gets its own trace + const traceIds = new Set(agentSpans.map((s) => s.spanContext().traceId)) + expect(traceIds.size).toBe(2) + + // Each trace has its own complete hierarchy + for (const agentSpan of agentSpans) { + const traceId = agentSpan.spanContext().traceId + const traceSpans = spans.filter((s) => s.spanContext().traceId === traceId) + const traceCycles = findSpans(traceSpans, CYCLE_SPAN_NAME) + const traceModels = findSpans(traceSpans, MODEL_SPAN_NAME) + + expect(traceCycles).toHaveLength(1) + expect(traceModels).toHaveLength(1) + assertParentChild(agentSpan, traceCycles[0]!) + assertParentChild(traceCycles[0]!, traceModels[0]!) + } + }) + + it('creates isolated traces for same-named concurrent agents', async () => { + const model1 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response A' }) + const model2 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Response B' }) + + const agent1 = new Agent({ model: model1, printer: false, name: 'shared-name' }) + const agent2 = new Agent({ model: model2, printer: false, name: 'shared-name' }) + + await Promise.all([agent1.invoke('Hello 1'), agent2.invoke('Hello 2')]) + + const spans = await flush() + const agentSpans = findSpans(spans, AGENT_SPAN_PREFIX) + + expect(agentSpans).toHaveLength(2) + expect(agentSpans.every((s) => s.name === 'invoke_agent shared-name')).toBe(true) + + // Same name but distinct traces + const traceIds = new Set(agentSpans.map((s) => s.spanContext().traceId)) + expect(traceIds.size).toBe(2) + + // Each trace still has a complete hierarchy with correct output + const expectedResponses = new Set(['Response A', 'Response B']) + for (const agentSpan of agentSpans) { + const traceId = agentSpan.spanContext().traceId + const traceSpans = spans.filter((s) => s.spanContext().traceId === traceId) + const traceCycles = findSpans(traceSpans, CYCLE_SPAN_NAME) + const traceModels = findSpans(traceSpans, MODEL_SPAN_NAME) + + expect(traceCycles).toHaveLength(1) + expect(traceModels).toHaveLength(1) + assertParentChild(agentSpan, traceCycles[0]!) + assertParentChild(traceCycles[0]!, traceModels[0]!) + + const choiceEvent = agentSpan.events.find((e) => e.name === 'gen_ai.choice') + expect(choiceEvent).toBeDefined() + const message = choiceEvent!.attributes!['message'] as string + expect(expectedResponses.has(message)).toBe(true) + expectedResponses.delete(message) + } + + // Both responses were seen + expect(expectedResponses.size).toBe(0) + }) + }) + + describe('getTracer', () => { + it('returns a tracer that produces spans captured by the registered provider', async () => { + const tracer = getTracer() + const span = tracer.startSpan('custom-operation') + span.setAttribute('custom.key', 'custom-value') + span.end() + + const spans = await flush() + const customSpans = spans.filter((s) => s.name === 'custom-operation') + + expect(customSpans).toHaveLength(1) + expect(attr(customSpans[0]!, 'custom.key')).toBe('custom-value') + }) + + // The OTel global tracer provider can only be set once per process via register(). + // Subsequent register() calls are no-ops and emit a warning. All spans always + // land in the first registered provider. + + it('ignores later register() calls — spans stay in the first registered provider', async () => { + const userExporter = new InMemorySpanExporter() + const userProvider = new NodeTracerProvider({ + spanProcessors: [new SimpleSpanProcessor(userExporter)], + }) + trace.setGlobalTracerProvider(userProvider) // no-op: global provider already set in beforeAll + + const tracer = getTracer() + const span = tracer.startSpan('user-provider-span') + span.setAttribute('source', 'custom-provider') + span.end() + + // Span lands in the original shared provider, not the user's + const spans = await flush() + const sharedSpan = spans.find((s) => s.name === 'user-provider-span') + expect(sharedSpan).toBeDefined() + expect(sharedSpan!.attributes['source']).toBe('custom-provider') + + // The user's exporter never receives the span + await userProvider.forceFlush() + const userSpans = userExporter.getFinishedSpans() + expect(userSpans.find((s) => s.name === 'user-provider-span')).toBeUndefined() + }) + + it('all spans land in the first registered provider even when multiple providers call register()', async () => { + const exporterA = new InMemorySpanExporter() + const providerA = new NodeTracerProvider({ + spanProcessors: [new SimpleSpanProcessor(exporterA)], + }) + trace.setGlobalTracerProvider(providerA) // no-op + + const exporterB = new InMemorySpanExporter() + const providerB = new NodeTracerProvider({ + spanProcessors: [new SimpleSpanProcessor(exporterB)], + }) + trace.setGlobalTracerProvider(providerB) // no-op + + const tracer = getTracer() + const span = tracer.startSpan('multi-register-span') + span.end() + + // Span lands in the original shared provider + const spans = await flush() + expect(spans.find((s) => s.name === 'multi-register-span')).toBeDefined() + + // Neither late provider receives the span + await providerA.forceFlush() + await providerB.forceFlush() + expect(exporterA.getFinishedSpans().find((s) => s.name === 'multi-register-span')).toBeUndefined() + expect(exporterB.getFinishedSpans().find((s) => s.name === 'multi-register-span')).toBeUndefined() + }) + + it('creates custom spans that nest under agent spans via context propagation', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hi' }) + const agent = new Agent({ model, printer: false, name: 'gettracer-nest-agent' }) + + await agent.invoke('Hello') + + const allSpans = await flush() + const agentReadableSpan = findSpans(allSpans, AGENT_SPAN_PREFIX)[0]! + + // Wrap the ReadableSpan's context into a live span reference for context propagation + const agentSpanRef = trace.wrapSpanContext(agentReadableSpan.spanContext()) + + // Create a custom span parented to the agent span via context + const tracer = getTracer() + context.with(trace.setSpan(context.active(), agentSpanRef), () => { + const childSpan = tracer.startSpan('custom-child') + childSpan.end() + }) + + const spansAfter = await flush() + const childSpan = spansAfter.find((s) => s.name === 'custom-child')! + + expect(childSpan).toBeDefined() + expect(childSpan.spanContext().traceId).toBe(agentReadableSpan.spanContext().traceId) + expect(childSpan.parentSpanContext?.spanId).toBe(agentReadableSpan.spanContext().spanId) + }) + }) +}) + +describe.sequential('Metrics Integration', () => { + let metricExporter: InMemoryMetricExporter + let metricReader: PeriodicExportingMetricReader + let meterProvider: MeterProvider + + const calculatorTool = tool({ + name: 'calculator', + description: 'Add two numbers', + inputSchema: z.object({ a: z.number(), b: z.number() }), + callback: ({ a, b }) => `${a + b}`, + }) + + beforeAll(() => { + metricExporter = new InMemoryMetricExporter(AggregationTemporality.CUMULATIVE) + metricReader = new PeriodicExportingMetricReader({ + exporter: metricExporter, + exportIntervalMillis: 100, + }) + meterProvider = new MeterProvider({ + readers: [metricReader], + }) + otelMetrics.setGlobalMeterProvider(meterProvider) + }) + + beforeEach(() => { + metricExporter.reset() + }) + + afterAll(async () => { + await meterProvider.forceFlush() + await meterProvider.shutdown() + }) + + async function collectMetrics(): Promise> { + await meterProvider.forceFlush() + return [...metricExporter.getMetrics()] + } + + it('emits cycle count metrics during agent invocation', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false, name: 'metrics-cycle-agent' }) + + await agent.invoke('Hi') + + const metrics = await collectMetrics() + const cycleCount = findMetricValue(metrics, 'gen_ai.agent.cycle.count') + + expect(cycleCount).toBeGreaterThanOrEqual(1) + }) + + it('emits invocation count metrics', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, printer: false, name: 'metrics-invocation-agent' }) + + await agent.invoke('Hi') + + const metrics = await collectMetrics() + const invocationCount = findMetricValue(metrics, 'gen_ai.agent.invocation.count') + + expect(invocationCount).toBeGreaterThanOrEqual(1) + }) + + it('emits token usage metrics', async () => { + const model = new MockMessageModel().addTurn( + { type: 'textBlock', text: 'Hello back' }, + { usage: { inputTokens: 50, outputTokens: 25, totalTokens: 75 } } + ) + + const agent = new Agent({ model, printer: false, name: 'metrics-token-agent' }) + + await agent.invoke('Hello') + + const metrics = await collectMetrics() + const inputTokens = findMetricValue(metrics, 'gen_ai.agent.tokens.input') + const outputTokens = findMetricValue(metrics, 'gen_ai.agent.tokens.output') + + expect(inputTokens).toBeGreaterThanOrEqual(50) + expect(outputTokens).toBeGreaterThanOrEqual(25) + }) + + it('emits tool call metrics when tools are used', async () => { + const model = new MockMessageModel() + .addTurn( + { type: 'toolUseBlock', name: 'calculator', toolUseId: 'tool-1', input: { a: 1, b: 2 } }, + { usage: { inputTokens: 30, outputTokens: 10, totalTokens: 40 } } + ) + .addTurn( + { type: 'textBlock', text: 'The answer is 3' }, + { usage: { inputTokens: 40, outputTokens: 15, totalTokens: 55 } } + ) + + const agent = new Agent({ model, printer: false, name: 'metrics-tool-agent', tools: [calculatorTool] }) + + await agent.invoke('Add 1 and 2') + + const metrics = await collectMetrics() + const toolCallCount = findMetricValue(metrics, 'gen_ai.agent.tool.call.count') + + expect(toolCallCount).toBeGreaterThanOrEqual(1) + }) + + it('emits cycle duration histogram', async () => { + const model = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Done' }) + const agent = new Agent({ model, printer: false, name: 'metrics-duration-agent' }) + + await agent.invoke('Hello') + + const metrics = await collectMetrics() + const durationValue = findMetricValue(metrics, 'gen_ai.agent.cycle.duration') + + expect(durationValue).toBeDefined() + }) + + it('emits metrics across multiple invocations cumulatively', async () => { + const model1 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'First' }) + const model2 = new MockMessageModel().addTurn({ type: 'textBlock', text: 'Second' }) + + const agent1 = new Agent({ model: model1, printer: false, name: 'metrics-multi-1' }) + const agent2 = new Agent({ model: model2, printer: false, name: 'metrics-multi-2' }) + + await agent1.invoke('Hello') + await agent2.invoke('World') + + const metrics = await collectMetrics() + const cycleCount = findMetricValue(metrics, 'gen_ai.agent.cycle.count') + + // At least 2 cycles (one per invocation) + expect(cycleCount).toBeGreaterThanOrEqual(2) + }) + + it('getMeter returns a meter that records real metrics', async () => { + const meter = getMeter() + const counter = meter.createCounter('test.custom.counter') + counter.add(7) + + const metrics = await collectMetrics() + expect(findMetricValue(metrics, 'test.custom.counter')).toBe(7) + }) +}) diff --git a/strands-ts/test/integ/tools/bash.test.node.ts b/strands-ts/test/integ/tools/bash.test.node.ts new file mode 100644 index 0000000000..bec51ea208 --- /dev/null +++ b/strands-ts/test/integ/tools/bash.test.node.ts @@ -0,0 +1,49 @@ +import { describe, expect, it } from 'vitest' +import { Agent } from '$/sdk/index.js' +import { bash } from '$/sdk/vended-tools/bash/index.js' +import { getMessageText } from '../__fixtures__/model-test-helpers.js' +import { bedrock } from '../__fixtures__/model-providers.js' + +describe.skipIf(bedrock.skip || process.platform === 'win32')('Bash Tool Integration', () => { + // Shared agent configuration for all tests + const createAgent = () => + new Agent({ + model: bedrock.createModel({ + region: 'us-east-1', + }), + tools: [bash], + }) + + describe('basic execution', () => { + it('captures stdout streams correctly', async () => { + const agent = createAgent() + const stdoutResult = await agent.invoke('Use bash to echo "Hello from bash"') + expect(getMessageText(stdoutResult.lastMessage)).toContain('Hello from bash') + }) + + it('captures stderr streams correctly', async () => { + const agent = createAgent() + const stderrResult = await agent.invoke('Use bash to run: echo "error" >&2') + expect(getMessageText(stderrResult.lastMessage)).toContain('error') + }) + + it('handles complex command patterns', async () => { + const agent = createAgent() + + // Test command sequencing + const seqResult = await agent.invoke('Use bash to: create a variable TEST=hello, then echo it') + expect(getMessageText(seqResult.lastMessage).toLowerCase()).toContain('hello') + }) + }) + + describe('error handling', () => { + it('handles command errors gracefully', async () => { + const agent = createAgent() + const result = await agent.invoke('Use bash to run: nonexistent_command_xyz') + + // Should indicate command not found or error + const lastMessage = getMessageText(result.lastMessage).toLowerCase() + expect(lastMessage).toMatch(/not found|error|command/) + }) + }) +}) diff --git a/strands-ts/test/integ/tools/file-editor.test.node.ts b/strands-ts/test/integ/tools/file-editor.test.node.ts new file mode 100644 index 0000000000..ed6d2cc681 --- /dev/null +++ b/strands-ts/test/integ/tools/file-editor.test.node.ts @@ -0,0 +1,164 @@ +import { afterEach, beforeEach, describe, expect, it } from 'vitest' +import { Agent } from '$/sdk/index.js' +import { fileEditor } from '$/sdk/vended-tools/file-editor/index.js' +import { collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { promises as fs } from 'fs' +import * as path from 'path' +import { tmpdir } from 'os' +import { bedrock } from '../__fixtures__/model-providers.js' + +describe.skipIf(bedrock.skip)('FileEditor Tool Integration', () => { + let testDir: string + + // Shared agent configuration for all tests + const createAgent = () => + new Agent({ + model: bedrock.createModel({ + region: 'us-east-1', + }), + tools: [fileEditor], + }) + + beforeEach(async () => { + // Create a temporary test directory + testDir = path.join(tmpdir(), `file-editor-integ-test-${Date.now()}-${Math.random().toString(36).slice(2)}`) + await fs.mkdir(testDir, { recursive: true }) + }) + + afterEach(async () => { + // Clean up test directory + try { + await fs.rm(testDir, { recursive: true, force: true }) + } catch (error) { + console.error('Failed to clean up test directory', testDir) + console.error(error) + } + }) + + it('should create and view a file via prompt', async () => { + const agent = createAgent() + const testFile = path.join(testDir, 'test.txt') + + // Create a file + await agent.invoke(`Create a file at ${testFile} with content "Hello World"`) + + // Verify file was created on disk + const fileContent = await fs.readFile(testFile, 'utf-8') + expect(fileContent).toBe('Hello World') + + // View the file + const { items: events } = await collectGenerator(agent.stream(`View the file at ${testFile}`)) + + // The agent should have received the file content + const textBlocks = events.filter((e: any) => e.type === 'contentBlockEvent' && e.contentBlock.type === 'textBlock') + expect(textBlocks.length).toBeGreaterThan(0) + }, 60000) + + it('should edit a file using str_replace', async () => { + const agent = createAgent() + const testFile = path.join(testDir, 'edit-test.txt') + + // Create initial file + await agent.invoke(`Create a file at ${testFile} with content "Hello OLD World"`) + + // Replace text + await agent.invoke(`In the file ${testFile}, replace "OLD" with "NEW"`) + + // Verify the replacement on disk + const fileContent = await fs.readFile(testFile, 'utf-8') + expect(fileContent).toBe('Hello NEW World') + }, 60000) + + it('should insert text at specific lines', async () => { + const agent = createAgent() + const testFile = path.join(testDir, 'insert-test.txt') + + // Create file with multiple lines + const initialContent = 'Line 1\nLine 2\nLine 3' + await agent.invoke(`Create a file at ${testFile} with content "${initialContent}"`) + + // Insert text at line 2 + await agent.invoke(`In the file ${testFile}, insert "Inserted Line" at line 2`) + + // Verify the insertion on disk + const fileContent = await fs.readFile(testFile, 'utf-8') + expect(fileContent).toBe('Line 1\nLine 2\nInserted Line\nLine 3') + }, 60000) + + it('should handle errors gracefully', async () => { + const agent = createAgent() + const nonExistentFile = path.join(testDir, 'does-not-exist.txt') + + // Try to view non-existent file + const { items: events } = await collectGenerator(agent.stream(`View the file at ${nonExistentFile}`)) + + // The agent should handle the error and provide a reasonable response + const toolResults = events.filter((e: any) => e.type === 'toolResultEvent') + expect(toolResults.length).toBeGreaterThan(0) + + // The model should have handled the error gracefully + const textBlocks = events.filter((e: any) => e.type === 'contentBlockEvent' && e.contentBlock.type === 'textBlock') + expect(textBlocks.length).toBeGreaterThan(0) + }, 60000) + + it('should view directory contents', async () => { + const agent = createAgent() + + // Create some files in the test directory + await fs.writeFile(path.join(testDir, 'file1.txt'), 'content1', 'utf-8') + await fs.writeFile(path.join(testDir, 'file2.txt'), 'content2', 'utf-8') + await fs.mkdir(path.join(testDir, 'subdir'), { recursive: true }) + await fs.writeFile(path.join(testDir, 'subdir', 'file3.txt'), 'content3', 'utf-8') + + // View the directory + const { items: events } = await collectGenerator(agent.stream(`List the files in directory ${testDir}`)) + + // The agent should have received the directory listing + const textBlocks = events.filter((e: any) => e.type === 'contentBlockEvent' && e.contentBlock.type === 'textBlock') + expect(textBlocks.length).toBeGreaterThan(0) + }, 60000) + + it('should handle multi-line file content', async () => { + const agent = createAgent() + const testFile = path.join(testDir, 'multiline-test.txt') + + // Create file with multiple lines + const multilineContent = `Line 1 +Line 2 +Line 3 +Line 4` + + await agent.invoke(`Create a file at ${testFile} with this content: +${multilineContent}`) + + // Verify file was created correctly + const fileContent = await fs.readFile(testFile, 'utf-8') + expect(fileContent).toContain('Line 1') + expect(fileContent).toContain('Line 4') + + // Replace multi-line content + await agent.invoke(`In the file ${testFile}, replace "Line 2 +Line 3" with "Replaced Lines"`) + + // Verify replacement + const updatedContent = await fs.readFile(testFile, 'utf-8') + expect(updatedContent).toContain('Replaced Lines') + expect(updatedContent).not.toContain('Line 2') + }, 60000) + + it('should handle view with line ranges', async () => { + const agent = createAgent() + const testFile = path.join(testDir, 'range-test.txt') + + // Create file with multiple lines + const content = 'Line 1\nLine 2\nLine 3\nLine 4\nLine 5' + await agent.invoke(`Create a file at ${testFile} with content "${content}"`) + + // View specific line range + const { items: events } = await collectGenerator(agent.stream(`View lines 2 to 4 of file ${testFile}`)) + + // The agent should have used view_range parameter + const toolResults = events.filter((e: any) => e.type === 'toolResultEvent') + expect(toolResults.length).toBeGreaterThan(0) + }, 60000) +}) diff --git a/strands-ts/test/integ/tools/http-request.test.ts b/strands-ts/test/integ/tools/http-request.test.ts new file mode 100644 index 0000000000..af53d0fb65 --- /dev/null +++ b/strands-ts/test/integ/tools/http-request.test.ts @@ -0,0 +1,23 @@ +import { describe, expect, it } from 'vitest' +import { httpRequest } from '@strands-agents/sdk/vended-tools/http-request' +import { Agent } from '@strands-agents/sdk' +import { bedrock } from '../__fixtures__/model-providers.js' + +describe.skipIf(bedrock.skip)('httpRequest tool (integration)', () => { + it('agent uses http_request tool to fetch weather from Open-Meteo', async () => { + const agent = new Agent({ + model: bedrock.createModel({ maxTokens: 500 }), + tools: [httpRequest], + printer: false, + }) + + const result = await agent.invoke('Call Open-Meteo to get the weather in NYC') + + // Verify agent made a request and returned weather information + expect(result.toString().toLowerCase()).toMatch(/weather|temperature|forecast|nyc|new york/) + + // Verify the result structure + expect(result.stopReason).toBe('endTurn') + expect(result.lastMessage.role).toBe('assistant') + }) +}) diff --git a/strands-ts/test/integ/tools/notebook.test.ts b/strands-ts/test/integ/tools/notebook.test.ts new file mode 100644 index 0000000000..fbb6cbd96c --- /dev/null +++ b/strands-ts/test/integ/tools/notebook.test.ts @@ -0,0 +1,103 @@ +import { describe, expect, it } from 'vitest' +import type { AgentResult, AgentStreamEvent } from '$/sdk/index.js' +import { Agent } from '$/sdk/index.js' +import { notebook } from '$/sdk/vended-tools/notebook/index.js' +import { collectGenerator } from '$/sdk/__fixtures__/model-test-helpers.js' +import { bedrock } from '../__fixtures__/model-providers.js' + +describe.skipIf(bedrock.skip)('Notebook Tool Integration', () => { + // Shared agent configuration for all tests + const agentParams = { + model: bedrock.createModel({ + region: 'us-east-1', + }), + tools: [notebook], + } + + it('should persist notebook state across tool invocations', async () => { + // Create agent with notebook tool + const agent = new Agent(agentParams) + + // Step 1: Create a notebook + const { items: _events1 } = await collectGenerator( + agent.stream('Create a notebook called "test" with content "# Test Notebook"') + ) + + // Verify notebook was created + const notebooks1 = agent.appState.get('notebooks') as any + expect(notebooks1).toBeTruthy() + expect(notebooks1).toHaveProperty('test') + expect(notebooks1.test).toContain('# Test Notebook') + + // Step 2: Add content to the notebook + const { items: _events2 } = await collectGenerator(agent.stream('Add "- First item" to the test notebook')) + + // Verify content was added + const notebooks2 = agent.appState.get('notebooks') as any + expect(notebooks2.test).toContain('- First item') + + // Step 3: Read the notebook + const { items: events3 } = await collectGenerator( + agent.stream('Read the test notebook') + ) + + // Find the last content block complete event with a text block to get agent's response + const textBlocks = events3.filter((e) => e.type === 'contentBlockEvent' && e.contentBlock.type === 'textBlock') + expect(textBlocks.length).toBeGreaterThan(0) + + // The notebook should still contain both pieces of content + const notebooks3 = agent.appState.get('notebooks') as any + expect(notebooks3.test).toContain('# Test Notebook') + expect(notebooks3.test).toContain('- First item') + }, 30000) // 30 second timeout for network calls + + it('should restore state across agent instances', async () => { + // Create first agent and add content + const agent1 = new Agent(agentParams) + + // Create notebook with first agent + await collectGenerator(agent1.stream('Create a notebook called "persist" with "Persistent content"')) + + // Verify notebook was created + const notebooks1 = agent1.appState.get('notebooks') as any + expect(notebooks1).toBeTruthy() + expect(notebooks1.persist).toContain('Persistent content') + + // Save state + const savedState = agent1.appState.getAll() + + // Create second agent with restored state + const agent2 = new Agent({ + ...agentParams, + appState: savedState, // Pass state in constructor + }) + + // Verify notebooks were restored + const notebooks2 = agent2.appState.get('notebooks') as any + expect(notebooks2).toBeTruthy() + expect(notebooks2.persist).toContain('Persistent content') + + // Use the restored notebook - just read it + await collectGenerator(agent2.stream('Read the persist notebook')) + + // Verify content still exists + const notebooks3 = agent2.appState.get('notebooks') as any + expect(notebooks3.persist).toContain('Persistent content') + }, 30000) + + it('should handle errors gracefully', async () => { + const agent = new Agent(agentParams) + + // Try to read non-existent notebook + const { items: events } = await collectGenerator(agent.stream('Read a notebook called "nonexistent"')) + + // The agent should handle the error and provide a reasonable response + // Check that we got tool result events (indicating tool was called) + const toolResults = events.filter((e) => e.type === 'toolResultEvent') + expect(toolResults.length).toBeGreaterThan(0) + + // The model should have handled the error gracefully + const textBlocks = events.filter((e) => e.type === 'contentBlockEvent' && e.contentBlock.type === 'textBlock') + expect(textBlocks.length).toBeGreaterThan(0) + }, 30000) +}) diff --git a/strands-ts/test/integ/tsconfig.json b/strands-ts/test/integ/tsconfig.json new file mode 100644 index 0000000000..f4fdc6a8ca --- /dev/null +++ b/strands-ts/test/integ/tsconfig.json @@ -0,0 +1,9 @@ +{ + "extends": "../../tsconfig.base.json", + "compilerOptions": { + "paths": { + "$/sdk/*": ["../../src/*"] + } + }, + "references": [{ "path": "../../src/tsconfig.json" }] +} diff --git a/strands-ts/test/integ/vended-interventions/steering/steering.test.node.ts b/strands-ts/test/integ/vended-interventions/steering/steering.test.node.ts new file mode 100644 index 0000000000..4dc0735e01 --- /dev/null +++ b/strands-ts/test/integ/vended-interventions/steering/steering.test.node.ts @@ -0,0 +1,99 @@ +import { describe, expect, it } from 'vitest' +import { z } from 'zod' +import { Agent, tool } from '$/sdk/index.js' +import { guide, proceed, type Guide, type Proceed } from '$/sdk/interventions/actions.js' +import { SteeringHandler, ToolLedgerProvider } from '$/sdk/vended-interventions/steering/index.js' +import type { BeforeToolCallEvent } from '$/sdk/hooks/events.js' +import { bedrock } from '../../__fixtures__/model-providers.js' + +const sendEmail = tool({ + name: 'send_email', + description: 'Send an email to a recipient', + inputSchema: z.object({ recipient: z.string(), message: z.string() }), + callback: async ({ recipient, message }) => `Email sent to ${recipient}: ${message}`, +}) + +const sendNotification = tool({ + name: 'send_notification', + description: 'Send a notification to a recipient', + inputSchema: z.object({ recipient: z.string(), message: z.string() }), + callback: async ({ recipient, message }) => `Notification sent to ${recipient}: ${message}`, +}) + +describe.skipIf(bedrock.skip)('Steering integration', () => { + const createModel = () => bedrock.createModel({ maxTokens: 1024 }) + + it('redirects send_email to send_notification via Guide', async () => { + class RedirectEmailHandler extends SteeringHandler { + override readonly name = 'redirect-email' + override async beforeToolCall(event: BeforeToolCallEvent): Promise { + if (event.toolUse.name === 'send_email') { + return guide('Use send_notification instead of send_email for better delivery.') + } + return proceed() + } + } + + const agent = new Agent({ + model: createModel(), + tools: [sendEmail, sendNotification], + interventions: [new RedirectEmailHandler()], + systemPrompt: + 'You are a helpful assistant. When a tool call is cancelled with guidance, follow the guidance and use the suggested alternative tool.', + printer: false, + }) + + const result = await agent.invoke('Send an email to john@example.com saying hello') + + const toolMetrics = result.metrics?.toolMetrics ?? {} + + if (toolMetrics.send_email) { + expect(toolMetrics.send_email.callCount).toBeGreaterThanOrEqual(1) + expect(toolMetrics.send_email.successCount).toBe(0) + } + + expect(toolMetrics.send_notification).toBeDefined() + expect(toolMetrics.send_notification!.callCount).toBeGreaterThanOrEqual(1) + expect(toolMetrics.send_notification!.successCount).toBeGreaterThanOrEqual(1) + }) + + it('ToolLedgerProvider captures tool calls during a real invocation', async () => { + const ledger = new ToolLedgerProvider() + + class LedgerCheckingHandler extends SteeringHandler { + override readonly name = 'ledger-check' + + override async beforeToolCall(event: BeforeToolCallEvent): Promise { + const calls = (ledger.context.calls ?? []) as Array> + const current = calls.find((c) => c.name === event.toolUse.name) + expect(current).toBeDefined() + expect(current?.args).toEqual(event.toolUse.input) + expect(current?.status).toBe('pending') + return proceed() + } + } + + const handler = new LedgerCheckingHandler({ contextProviders: [ledger] }) + + const agent = new Agent({ + model: createModel(), + tools: [sendNotification], + interventions: [handler], + printer: false, + }) + + await agent.invoke('Send a notification to alice saying test message') + + const calls = (ledger.context.calls ?? []) as Array> + expect(calls.length).toBeGreaterThanOrEqual(1) + + const last = calls[calls.length - 1]! + expect(last.name).toBe('send_notification') + const args = last.args as Record + expect(args.recipient).toBe('alice') + expect(args.message).toContain('test message') + expect(last.status).toBe('success') + expect(last.endTime).toBeTypeOf('string') + expect(last.error).toBeNull() + }) +}) diff --git a/strands-ts/test/integ/vitest.d.ts b/strands-ts/test/integ/vitest.d.ts new file mode 100644 index 0000000000..0a5988de3e --- /dev/null +++ b/strands-ts/test/integ/vitest.d.ts @@ -0,0 +1,29 @@ +import 'vitest' +import type { AwsCredentialIdentity } from '@aws-sdk/types' + +declare module 'vitest' { + export interface ProvidedContext { + isCI: boolean + isBrowser: boolean + ['provider-openai']: { + shouldSkip: boolean + apiKey: string | undefined + } + ['provider-bedrock']: { + shouldSkip: boolean + credentials: AwsCredentialIdentity | undefined + } + ['provider-anthropic']: { + shouldSkip: boolean + apiKey: string | undefined + } + ['provider-gemini']: { + shouldSkip: boolean + apiKey: string | undefined + } + ['a2a-server']: { + shouldSkip: boolean + url: string | undefined + } + } +} diff --git a/strands-ts/test/packages/README.md b/strands-ts/test/packages/README.md new file mode 100644 index 0000000000..dd82ea8f87 --- /dev/null +++ b/strands-ts/test/packages/README.md @@ -0,0 +1,33 @@ +# Package Import Tests + +This directory contains verification tests to ensure `@strands-agents/sdk` can be imported correctly. There are two flavors, catching different classes of packaging bug: + +- **`esm-module/` and `cjs-module/`** — fast local tests. Install the SDK via `file:../../..` and exercise ESM `import` + CommonJS `require`. Run by `npm run test:package`. These resolve through the monorepo, so they share the root `node_modules` and cannot detect missing-optional-peer regressions. +- **`npm-pack/`** — CI-only smoke test (`.github/workflows/test-package-pack.yml`). Runs `npm pack` and installs the tarball in a tempdir outside the monorepo, mirroring an end-user install. Catches the RC.0 class of bug where the main entry re-exports a symbol from an optional peer dependency. + +## Running the Tests + +From the root of the project: + +```bash +npm run test:package +``` + +This command builds and installs the SDK locally, then runs both ESM and CJS import tests. The tarball test is not wired into this script — see `.github/workflows/test-package-pack.yml` for its invocation. + +## Test Structure + +``` +test/packages/ +├── esm-module/ # ES Module import test (file: install) +│ ├── esm.js # Uses `import { ... } from '@strands-agents/sdk'` +│ └── package.json +├── cjs-module/ # CommonJS import test (file: install) +│ ├── cjs.js # Uses `require('@strands-agents/sdk')` +│ └── package.json +├── npm-pack/ # Packed-tarball install smoke test (CI-only) +│ ├── verify.ts # Type-checked consumer script +│ ├── package.json +│ └── tsconfig.json +└── README.md +``` diff --git a/strands-ts/test/packages/cjs-module/cjs.js b/strands-ts/test/packages/cjs-module/cjs.js new file mode 100644 index 0000000000..96857d8a53 --- /dev/null +++ b/strands-ts/test/packages/cjs-module/cjs.js @@ -0,0 +1,105 @@ +/** + * Verification script to ensure the built package can be imported from a + * pure-CJS Node project via dynamic import(). The SDK itself is ESM-only; + * CJS consumers interop by using await import(). + */ + +async function main() { + const { Agent, BedrockModel, tool, Tool } = await import('@strands-agents/sdk') + + const { notebook } = await import('@strands-agents/sdk/vended-tools/notebook') + const { fileEditor } = await import('@strands-agents/sdk/vended-tools/file-editor') + const { httpRequest } = await import('@strands-agents/sdk/vended-tools/http-request') + const { bash } = await import('@strands-agents/sdk/vended-tools/bash') + + const { + bash: barrelBash, + fileEditor: barrelFileEditor, + httpRequest: barrelHttpRequest, + notebook: barrelNotebook, + } = await import('@strands-agents/sdk/vended-tools') + + const { + AgentSkills, + ContextOffloader, + InMemoryStorage, + } = await import('@strands-agents/sdk/vended-plugins') + + const { BedrockModel: BedrockFromSubpath } = await import('@strands-agents/sdk/models/bedrock') + const { OpenAIModel } = await import('@strands-agents/sdk/models/openai') + const { AnthropicModel } = await import('@strands-agents/sdk/models/anthropic') + const { GoogleModel } = await import('@strands-agents/sdk/models/google') + + const { z } = await import('zod') + + console.log('✓ Import from main entry point successful') + + const model = new BedrockModel({ region: 'us-west-2' }) + console.log('✓ BedrockModel instantiation successful') + + const config = model.getConfig() + if (!config) { + throw new Error('BedrockModel config is invalid') + } + console.log('✓ BedrockModel configuration retrieval successful') + + const example_tool = tool({ + name: 'get_weather', + description: 'Get the current weather for a specific location.', + inputSchema: z.object({ + location: z.string().describe('The city and state, e.g., San Francisco, CA'), + }), + callback: (input) => { + console.log(`\n[WeatherTool] Getting weather for ${input.location}...`) + return `The weather in ${input.location} is 72°F and sunny.` + }, + }) + console.log('✓ Tool created successful') + + const response = await example_tool.invoke({ location: 'New York' }) + if (response !== `The weather in New York is 72°F and sunny.`) { + throw new Error('Tool returned invalid response') + } + + const agent = new Agent({ + tools: [example_tool], + }) + + if (agent.tools.length == 0) { + throw new Error('Tool was not correctly added to the agent') + } + + const tools = { notebook, fileEditor, httpRequest, bash } + for (const tool of Object.values(tools)) { + if (!(tool instanceof Tool)) { + throw new Error(`Tool ${tool.name} isn't an instance of a tool`) + } + } + + if (BedrockFromSubpath !== BedrockModel) { + throw new Error('BedrockModel from subpath should match main export') + } + console.log('✓ Model subpath exports verified') + + // Verify barrel exports match individual subpath exports + if (barrelBash !== bash || barrelFileEditor !== fileEditor || barrelHttpRequest !== httpRequest || barrelNotebook !== notebook) { + throw new Error('Barrel vended-tools exports do not match individual subpath exports') + } + console.log('✓ Barrel vended-tools exports verified') + + // Verify barrel vended-plugins exports are constructible + if (typeof AgentSkills !== 'function' || typeof ContextOffloader !== 'function' || typeof InMemoryStorage !== 'function') { + throw new Error('Barrel vended-plugins exports are not constructible') + } + console.log('✓ Barrel vended-plugins exports verified') + + // Reference remaining imports so static analysis doesn't flag them unused. + void OpenAIModel + void AnthropicModel + void GoogleModel +} + +main().catch((error) => { + console.error(error) + process.exit(1) +}) diff --git a/strands-ts/test/packages/cjs-module/package.json b/strands-ts/test/packages/cjs-module/package.json new file mode 100644 index 0000000000..e2559ecd5f --- /dev/null +++ b/strands-ts/test/packages/cjs-module/package.json @@ -0,0 +1,10 @@ +{ + "type": "commonjs", + "name": "test-package", + "version": "1.0.0", + "private": true, + "description": "Test package to verify SDK works with CSJ", + "dependencies": { + "@strands-agents/sdk": "file:../../.." + } +} \ No newline at end of file diff --git a/strands-ts/test/packages/esm-module/esm.js b/strands-ts/test/packages/esm-module/esm.js new file mode 100644 index 0000000000..d019769f2f --- /dev/null +++ b/strands-ts/test/packages/esm-module/esm.js @@ -0,0 +1,137 @@ +/** + * Verification script to ensure the built package can be imported without a bundler. + * This script runs in a pure Node.js ES module environment. + */ + +import { Agent, BedrockModel, tool, Tool } from '@strands-agents/sdk' + +import { notebook } from '@strands-agents/sdk/vended-tools/notebook' +import { fileEditor } from '@strands-agents/sdk/vended-tools/file-editor' +import { httpRequest } from '@strands-agents/sdk/vended-tools/http-request' +import { bash } from '@strands-agents/sdk/vended-tools/bash' + +import { + bash as barrelBash, + fileEditor as barrelFileEditor, + httpRequest as barrelHttpRequest, + notebook as barrelNotebook, +} from '@strands-agents/sdk/vended-tools' + +import { + AgentSkills, + ContextOffloader, + InMemoryStorage, +} from '@strands-agents/sdk/vended-plugins' + +// Verify model subpath exports +import { BedrockModel as BedrockFromSubpath } from '@strands-agents/sdk/models/bedrock' +import { OpenAIModel } from '@strands-agents/sdk/models/openai' +import { AnthropicModel } from '@strands-agents/sdk/models/anthropic' +import { GoogleModel } from '@strands-agents/sdk/models/google' + +import { z } from 'zod' + +console.log('✓ Import from main entry point successful') + +// Verify BedrockModel can be instantiated +const model = new BedrockModel({ region: 'us-west-2' }) +console.log('✓ BedrockModel instantiation successful') + +// Verify basic functionality +const config = model.getConfig() +if (!config) { + throw new Error('BedrockModel config is invalid') +} +console.log('✓ BedrockModel configuration retrieval successful') + +// Define a tool +const example_tool = tool({ + name: 'get_weather', + description: 'Get the current weather for a specific location.', + inputSchema: z.object({ + location: z.string().describe('The city and state, e.g., San Francisco, CA'), + }), + callback: (input) => { + console.log(`\n[WeatherTool] Getting weather for ${input.location}...`) + + const fakeWeatherData = { + temperature: '72°F', + conditions: 'sunny', + } + + return `The weather in ${input.location} is ${fakeWeatherData.temperature} and ${fakeWeatherData.conditions}.` + }, +}) +console.log('✓ Tool created successful') + +// Verify tool can be called +const response = await example_tool.invoke({ location: 'New York' }) +if (response !== `The weather in New York is 72°F and sunny.`) { + throw new Error('Tool returned invalid response') +} + +// Verify Agent can be instantiated +const agent = new Agent({ + tools: [example_tool], +}) + +if (agent.tools.length == 0) { + throw new Error('Tool was not correctly added to the agent') +} + +async function validateScratchpad() { + let context = { agent: agent } + notebook.invoke( + { + mode: 'create', + name: 'scratchpad', + newStr: 'Content', + }, + context + ) + + const result = await notebook.invoke( + { + mode: 'read', + name: 'scratchpad', + }, + context + ) + + if (result !== 'Content') { + throw new Error(`Tool returned invalid response: ${result}`) + } + + console.log('Notebook created successful') +} + +const tools = { + notebook, + fileEditor, + httpRequest, + bash, +} + +for (const tool of Object.values(tools)) { + if (!(tool instanceof Tool)) { + throw new Error(`Tool ${tool.name} isn't an instance of a tool`) + } +} + +// Verify model subpath exports resolve correctly +if (BedrockFromSubpath !== BedrockModel) { + throw new Error('BedrockModel from subpath should match main export') +} +console.log('✓ Model subpath exports verified') + +// Verify barrel exports match individual subpath exports +if (barrelBash !== bash || barrelFileEditor !== fileEditor || barrelHttpRequest !== httpRequest || barrelNotebook !== notebook) { + throw new Error('Barrel vended-tools exports do not match individual subpath exports') +} +console.log('✓ Barrel vended-tools exports verified') + +// Verify barrel vended-plugins exports are constructible +if (typeof AgentSkills !== 'function' || typeof ContextOffloader !== 'function' || typeof InMemoryStorage !== 'function') { + throw new Error('Barrel vended-plugins exports are not constructible') +} +console.log('✓ Barrel vended-plugins exports verified') diff --git a/strands-ts/test/packages/esm-module/package.json b/strands-ts/test/packages/esm-module/package.json new file mode 100644 index 0000000000..467522fccd --- /dev/null +++ b/strands-ts/test/packages/esm-module/package.json @@ -0,0 +1,10 @@ +{ + "type": "module", + "name": "test-package", + "version": "1.0.0", + "private": true, + "description": "Test package to verify SDK works without bundler", + "dependencies": { + "@strands-agents/sdk": "file:../../.." + } +} \ No newline at end of file diff --git a/strands-ts/test/packages/npm-pack/package.json b/strands-ts/test/packages/npm-pack/package.json new file mode 100644 index 0000000000..ec42d9c7c2 --- /dev/null +++ b/strands-ts/test/packages/npm-pack/package.json @@ -0,0 +1,10 @@ +{ + "//": "Fixture for .github/workflows/test-package-pack.yml. Copied to a tempdir outside the monorepo, then the SDK tarball is installed on top. Only dev tooling lives here — required peer deps are auto-installed from the tarball's peerDependencies metadata; optional peers are deliberately omitted so module-load-time references to them fail the test.", + "private": true, + "type": "module", + "devDependencies": { + "typescript": "^6.0.2", + "tsx": "^4.21.0", + "@types/node": "^25.6.0" + } +} diff --git a/strands-ts/test/packages/npm-pack/tsconfig.json b/strands-ts/test/packages/npm-pack/tsconfig.json new file mode 100644 index 0000000000..e3f5c0d765 --- /dev/null +++ b/strands-ts/test/packages/npm-pack/tsconfig.json @@ -0,0 +1,21 @@ +// Consumer-side type-check config for the packaging smoke test. Runs against +// the installed @strands-agents/sdk tarball's .d.ts surface, not the SDK +// source. Kept minimal on purpose — we want errors in verify.ts, not false +// positives from stricter-than-needed options. +{ + "compilerOptions": { + "target": "ES2022", + "module": "NodeNext", + "moduleResolution": "nodenext", + "lib": ["ES2022"], + "strict": true, + "esModuleInterop": true, + "forceConsistentCasingInFileNames": true, + "skipLibCheck": true, + "resolveJsonModule": true, + "isolatedModules": true, + "types": ["node"], + "noEmit": true + }, + "include": ["verify.ts"] +} diff --git a/strands-ts/test/packages/npm-pack/verify.ts b/strands-ts/test/packages/npm-pack/verify.ts new file mode 100644 index 0000000000..0fb36233b4 --- /dev/null +++ b/strands-ts/test/packages/npm-pack/verify.ts @@ -0,0 +1,140 @@ +/** + * Consumer fixture for .github/workflows/test-package-pack.yml. Runs against + * the packed tarball with only non-optional peers installed, so any import + * that transitively pulls an optional peer fails at module load. + * + * Subpaths deliberately NOT imported because they require optional peers: + * models/{anthropic,openai,google,vercel}, a2a, a2a/express, + * session/s3-storage, telemetry. Those are covered by the sibling + * `../esm-module` and `../cjs-module` suites. + */ + +import { + Agent, + AgentResult, + BedrockModel, + ContextWindowOverflowError, + FunctionTool, + Model, + StateStore, + Tool, + ZodTool, + tool, +} from '@strands-agents/sdk' + +import { notebook } from '@strands-agents/sdk/vended-tools/notebook' +import { fileEditor } from '@strands-agents/sdk/vended-tools/file-editor' +import { httpRequest } from '@strands-agents/sdk/vended-tools/http-request' +import { bash } from '@strands-agents/sdk/vended-tools/bash' + +import { + bash as barrelBash, + fileEditor as barrelFileEditor, + httpRequest as barrelHttpRequest, + notebook as barrelNotebook, +} from '@strands-agents/sdk/vended-tools' + +import { + AgentSkills as BarrelAgentSkills, + ContextOffloader as BarrelContextOffloader, + InMemoryStorage as BarrelInMemoryStorage, +} from '@strands-agents/sdk/vended-plugins' + +import { BedrockModel as BedrockFromSubpath } from '@strands-agents/sdk/models/bedrock' +import { Graph, Swarm, MultiAgentState } from '@strands-agents/sdk/multiagent' +import { AgentSkills } from '@strands-agents/sdk/vended-plugins/skills' +import { ContextOffloader, InMemoryStorage } from '@strands-agents/sdk/vended-plugins/context-offloader' + +import { z } from 'zod' + +console.log('[pack-test] Imports resolved') + +const model = new BedrockModel({ region: 'us-west-2' }) +if (!model.getConfig()) { + throw new Error('BedrockModel config is invalid') +} +console.log('[pack-test] BedrockModel constructed') + +const weatherTool = tool({ + name: 'get_weather', + description: 'Get the current weather for a specific location.', + inputSchema: z.object({ + location: z.string().describe('The city and state, e.g., San Francisco, CA'), + }), + callback: (input) => `The weather in ${input.location} is 72F and sunny.`, +}) + +const response = await weatherTool.invoke({ location: 'New York' }) +if (response !== 'The weather in New York is 72F and sunny.') { + throw new Error(`Tool returned invalid response: ${String(response)}`) +} +console.log('[pack-test] Tool invocation produced expected output') + +const agent = new Agent({ model, tools: [weatherTool] }) +if (agent.tools.length === 0) { + throw new Error('Tool was not correctly added to the agent') +} +console.log('[pack-test] Agent constructed with tool') + +const vendedTools: Record = { notebook, fileEditor, httpRequest, bash } +for (const [name, t] of Object.entries(vendedTools)) { + if (!(t instanceof Tool)) { + throw new Error(`Vended tool '${name}' is not a Tool instance`) + } +} +console.log('[pack-test] All vended tools are Tool instances') + +if (BedrockFromSubpath !== BedrockModel) { + throw new Error('BedrockModel from subpath does not match main export') +} +if (!(model instanceof Model)) { + throw new Error('BedrockModel is not a Model instance') +} +if (!(weatherTool instanceof FunctionTool) && !(weatherTool instanceof ZodTool)) { + throw new Error('tool() factory returned an unexpected Tool subclass') +} +console.log('[pack-test] Subpath export identity + model/tool hierarchy verified') + +const store = new StateStore({ count: 0 }) +store.set('count', 1) +if (store.get('count') !== 1) { + throw new Error('StateStore did not round-trip value') +} +console.log('[pack-test] StateStore round-trip verified') + +const multiAgentState = new MultiAgentState() +if (!(multiAgentState instanceof MultiAgentState)) { + throw new Error('MultiAgentState construction failed') +} +const skills = new AgentSkills({ skills: [] }) +if (!(skills instanceof AgentSkills)) { + throw new Error('AgentSkills construction failed') +} +const offloader = new ContextOffloader({ storage: new InMemoryStorage() }) +if (!(offloader instanceof ContextOffloader)) { + throw new Error('ContextOffloader construction failed') +} +for (const [name, ctor] of Object.entries({ Graph, Swarm })) { + if (typeof ctor !== 'function') { + throw new Error(`${name} subpath export is not a constructor`) + } +} +console.log('[pack-test] multiagent + vended-plugin subpaths constructible') + +const ctxErr = new ContextWindowOverflowError('test') +if (!(ctxErr instanceof Error)) { + throw new Error('ContextWindowOverflowError is not an Error subclass') +} + +void AgentResult +console.log('[pack-test] Error + result types importable') + +if (barrelBash !== bash || barrelFileEditor !== fileEditor || barrelHttpRequest !== httpRequest || barrelNotebook !== notebook) { + throw new Error('Barrel vended-tools exports do not match individual subpath exports') +} +if (BarrelAgentSkills !== AgentSkills || BarrelContextOffloader !== ContextOffloader || BarrelInMemoryStorage !== InMemoryStorage) { + throw new Error('Barrel vended-plugins exports do not match individual subpath exports') +} +console.log('[pack-test] barrel exports match individual subpath exports') + +console.log('[pack-test] OK') diff --git a/strands-ts/tsconfig.base.json b/strands-ts/tsconfig.base.json new file mode 100644 index 0000000000..b8354b6299 --- /dev/null +++ b/strands-ts/tsconfig.base.json @@ -0,0 +1,32 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "NodeNext", + "moduleResolution": "nodenext", + "lib": ["ES2022", "DOM", "DOM.Iterable"], + "composite": true, + "allowJs": false, + "declaration": true, + "declarationMap": true, + "outDir": "./dist", + "rootDir": ".", + "strict": true, + "noImplicitAny": true, + "strictNullChecks": true, + "strictFunctionTypes": true, + "noImplicitReturns": true, + "noFallthroughCasesInSwitch": true, + "noUncheckedIndexedAccess": true, + "exactOptionalPropertyTypes": true, + "esModuleInterop": true, + "allowSyntheticDefaultImports": true, + "forceConsistentCasingInFileNames": true, + "skipLibCheck": true, + "resolveJsonModule": true, + "isolatedModules": true, + "verbatimModuleSyntax": true, + "sourceMap": true, + "removeComments": false, + "types": ["vite/client", "@types/node"] + } +} diff --git a/strands-ts/vitest.config.ts b/strands-ts/vitest.config.ts new file mode 100644 index 0000000000..657ebc62d0 --- /dev/null +++ b/strands-ts/vitest.config.ts @@ -0,0 +1,130 @@ +import { defineConfig } from 'vitest/config' +import { playwright } from '@vitest/browser-playwright' +import * as path from 'node:path' +import { fileURLToPath } from 'url' + +const __dirname = path.dirname(fileURLToPath(import.meta.url)) + +// Conditionally exclude bash tool from coverage on Windows +// since tests are skipped on Windows (bash not available) +const coverageExclude = ['src/**/__tests__/**', 'src/**/__fixtures__/**', 'src/vended-tools/**/__tests__/**', 'src/vended-plugins/**/__tests__/**'] +if (process.platform === 'win32') { + coverageExclude.push('src/vended-tools/bash/**') +} + +export default defineConfig({ + test: { + unstubEnvs: true, + reporters: [ + 'default', + ['junit', { outputFile: 'test/.artifacts/test-report/junit/report.xml', includeConsoleOutput: true }], + ['json', { outputFile: 'test/.artifacts/test-report/json/report.json' }], + ], + projects: [ + { + test: { + include: [ + 'src/**/__tests__/**/*.test.ts', + 'src/**/__tests__/**/*.test.node.ts', + 'src/vended-tools/**/__tests__/**/*.test.ts', + 'src/vended-tools/**/__tests__/**/*.test.node.ts', + 'src/vended-plugins/**/__tests__/**/*.test.ts', + 'src/vended-plugins/**/__tests__/**/*.test.node.ts', + ], + name: { label: 'unit-node', color: 'green' }, + typecheck: { + enabled: true, + tsconfig: 'src/tsconfig.json', + include: ['src/**/__tests__**/*.test-d.ts'], + }, + }, + }, + { + test: { + include: [ + 'src/**/__tests__/**/*.test.ts', + 'src/**/__tests__/**/*.test.browser.ts', + 'src/vended-tools/**/__tests__/**/*.test.ts', + 'src/vended-tools/**/__tests__/**/*.test.browser.ts', + 'src/vended-plugins/**/__tests__/**/*.test.ts', + 'src/vended-plugins/**/__tests__/**/*.test.browser.ts', + ], + name: { label: 'unit-browser', color: 'cyan' }, + browser: { + enabled: true, + provider: playwright(), + headless: true, + screenshotDirectory: 'test/.artifacts/browser-screenshots/', + instances: [ + { + browser: 'chromium', + }, + ], + }, + }, + }, + { + test: { + alias: { + '$/sdk': path.resolve(__dirname, './src'), + '$/vended': path.resolve(__dirname, './src/vended-tools'), + }, + include: ['test/integ/**/*.test.ts', 'test/integ/**/*.test.node.ts'], + name: { label: 'integ-node', color: 'magenta' }, + testTimeout: 60 * 1000, + retry: 1, + globalSetup: './test/integ/__fixtures__/_setup-global.ts', + setupFiles: './test/integ/__fixtures__/_setup-test.ts', + sequence: { + concurrent: true, + }, + }, + }, + { + test: { + alias: { + '$/sdk': path.resolve(__dirname, './src'), + '$/vended': path.resolve(__dirname, './src/vended-tools'), + }, + include: ['test/integ/**/*.test.ts', 'test/integ/**/*.test.browser.ts'], + name: { label: 'integ-browser', color: 'yellow' }, + testTimeout: 60 * 1000, + retry: 1, + browser: { + enabled: true, + provider: playwright(), + headless: true, + screenshotDirectory: 'test/.artifacts/browser-screenshots/', + instances: [ + { + browser: 'chromium', + }, + ], + }, + globalSetup: './test/integ/__fixtures__/_setup-global.ts', + setupFiles: './test/integ/__fixtures__/_setup-test.ts', + sequence: { + concurrent: true, + }, + }, + }, + ], + typecheck: { + enabled: true, + }, + coverage: { + provider: 'v8', + reporter: ['text', 'json', 'html'], + reportsDirectory: 'test/.artifacts/coverage', + include: ['src/**/*.{ts,js}', 'src/vended-tools/**/*.{ts,js}', 'src/vended-plugins/**/*.{ts,js}'], + exclude: coverageExclude, + thresholds: { + lines: 80, + functions: 80, + branches: 80, + statements: 80, + }, + }, + environment: 'node', + }, +}) diff --git a/strands-wasm/README.md b/strands-wasm/README.md new file mode 100644 index 0000000000..64a32e7e70 --- /dev/null +++ b/strands-wasm/README.md @@ -0,0 +1,139 @@ +# strands-wasm + +WASM build tooling and monorepo developer guide. Describes the WebAssembly component architecture, build pipeline, WIT contracts, and cross-package development workflow. + +## How it works + +The TypeScript SDK is compiled into a WebAssembly component (`strands-agent.wasm`). Python loads this component via wasmtime-py and drives it. + +The WIT contract (`wit/agent.wit`) defines what crosses the WASM boundary: + +- **Exports** (TS implements, Python calls): The `api` interface — agent construction, streaming, conversation management. All model provider HTTP calls (Bedrock, Anthropic, OpenAI, Gemini) happen inside the WASM guest. +- **Imports** (Python implements, TS calls back into): `tool-provider` for executing Python-defined tools, and `host-log` for routing log entries to Python's logging framework. + +In WIT terminology, the WASM component is the "guest" and Python is the "host". When the TS agent loop decides a tool needs to run, it calls the `tool-provider` import which crosses the WASM boundary back to Python where the actual tool function lives. + +## Getting started + +### Prerequisites + +- Node.js 20+ +- Python 3.10+ +- [wasmtime-py](https://github.com/bytecodealliance/wasmtime-py) (forked build with async component model support) + +### First-time setup + +```bash +git clone https://github.com/strands-agents/sdk-typescript.git +cd sdk-typescript +npm install +npm run dev -- bootstrap +``` + +`bootstrap` installs toolchains, generates type bindings, builds all layers, and runs all tests. If this command doesn't enable development out of the box, file an issue. + +## Architecture + +### Build pipeline + +Changes flow through a pipeline. Each layer compiles into the next: + +```mermaid +graph TD + WIT["wit/agent.wit"] -->|generate| TS_GEN["strands-ts/generated/"] + WIT -->|generate| WASM_GEN["strands-wasm/generated/"] + + TS_GEN --> TS["strands-ts (npm build)"] + TS -->|esbuild bundle| WASM_BUNDLE["strands-wasm (ESM bundle)"] + WASM_GEN --> WASM_BUNDLE + WASM_BUNDLE -->|componentize-js| WASM["agent.wasm (WASM component)"] + WASM -->|wasmtime-py| PY["strands-py-wasm (Python package)"] +``` + +| Directory | Language | What it is | +| -------------- | ---------- | ------------------------------------------------------------------- | +| `wit/` | WIT | Interface contract between the WASM guest and host | +| `strands-ts/` | TypeScript | Agent runtime: event loop, model providers, tools, hooks, streaming | +| `strands-wasm/` | TypeScript | Bridges the TS SDK to WIT exports, compiles to a WASM component | +| `strands-py-wasm/` | Python | Python wrapper: Agent class, @tool decorator, direct WASM host | +| `strands-dev/` | TypeScript | Dev CLI that orchestrates build, test, lint, and CI | +| `dev-docs/` | Markdown | Design proposal and team decisions | + +### Generated code + +`npm run dev -- generate` produces type bindings from `wit/agent.wit` into: + +- `strands-ts/generated/` +- `strands-wasm/generated/` + +Generated files are created by running `npm run dev -- generate` (or `bootstrap`) and are gitignored. Do not edit them by hand. CI runs `generate --check` and fails if they are stale. + +Python types are auto-generated into `strands-py-wasm/strands/_generated/types.py` by `strands-py-wasm/scripts/generate_types.py`. + +### Tests + +| Layer | Framework | Location | +| -------------- | --------- | ----------------------------------------------------------------- | +| TypeScript SDK | vitest | `strands-ts/src/**/__tests__/` (unit), `strands-ts/test/` (integ) | +| Python wrapper | pytest | `strands-py-wasm/tests_integ/` | + +Add tests alongside the code you change. Bug fixes should include a test that reproduces the original issue. + +## Making changes + +Each layer depends on the layers above it in the pipeline. The `validate` command rebuilds and tests exactly the layers your change affects. + +| What you changed | Validate command | +| ------------------------------------- | ------------------------------------- | +| WIT contract (`wit/agent.wit`) | `npm run dev -- validate wit` | +| TS SDK internals | `npm run dev -- validate ts` | +| TS SDK public API | `npm run dev -- validate ts-api` | +| WASM bridge (`strands-wasm/entry.ts`) | `npm run dev -- validate wasm` | +| Pure Python (`strands-py-wasm/`) | `npm run dev -- validate py` | + +**TS internals vs. public API:** The WASM bridge (`strands-wasm/entry.ts`) imports specific types and functions from `strands-ts/`. If your change modifies something the bridge imports, it is a public API change — use `validate ts-api`. If the bridge does not import it, use `validate ts`. + +**WIT contract changes** cascade to every layer. After running `validate wit`, fix any compile errors in `strands-wasm/entry.ts` and the language wrappers. The build will not succeed until every layer matches the new contract. + +## Dev CLI + +```bash +npm run dev -- [options] +``` + +Most commands accept layer flags (`--ts`, `--wasm`, `--py`). No flags means all layers. + +| Command | What it does | +| ------------------ | ---------------------------------------------------------------------- | +| `bootstrap` | First-time setup: install, generate, build, test | +| `setup` | Install toolchains (`--node`, `--python`) | +| `generate` | Regenerate type bindings from WIT (`--check`) | +| `build` | Compile layers (`--ts`, `--wasm`, `--py`, `--release`) | +| `test` | Run tests (`--py`, `--ts`, or a specific `[file]`) | +| `check` | Lint and type-check (`--ts`, `--py`) | +| `fmt` | Format all code (`--check` to verify without writing) | +| `validate ` | Rebuild and test the layers affected by a change | +| `ci` | Full pipeline: generate, format, lint, build, test | +| `rebuild` | Clean rebuild: clean, generate, build | +| `clean` | Remove all build artifacts | +| `example ` | Run an example (`--py`, `--ts`) | + +## Code style + +| Language | Formatter | Linter | +| ---------- | ------------- | -------------- | +| TypeScript | `prettier` | `tsc --noEmit` | +| Python | `ruff format` | `ruff check` | + +```bash +npm run dev -- fmt # format everything +npm run dev -- check # lint everything +``` + +Comments are normative statements that describe what code does or why a decision was made. Avoid TODO's without associated issues, notes-to-self, and parenthetical asides. + +## Submitting a PR + +- Run `npm run dev -- ci` before pushing. This is the same pipeline CI runs. +- Keep PRs focused on a single change. +- Use conventional commit messages: `feat:`, `fix:`, `refactor:`, `docs:`, etc. diff --git a/strands-wasm/__fixtures__/host-log.ts b/strands-wasm/__fixtures__/host-log.ts new file mode 100644 index 0000000000..87ec26969d --- /dev/null +++ b/strands-wasm/__fixtures__/host-log.ts @@ -0,0 +1,2 @@ +import { vi } from 'vitest' +export const log = vi.fn() diff --git a/strands-wasm/__fixtures__/tool-provider.ts b/strands-wasm/__fixtures__/tool-provider.ts new file mode 100644 index 0000000000..d2e291c948 --- /dev/null +++ b/strands-wasm/__fixtures__/tool-provider.ts @@ -0,0 +1,3 @@ +import { vi } from 'vitest' +export const callTool = vi.fn() +export const callTools = vi.fn() diff --git a/strands-wasm/__tests__/lifecycle.test.ts b/strands-wasm/__tests__/lifecycle.test.ts new file mode 100644 index 0000000000..f4d938bba9 --- /dev/null +++ b/strands-wasm/__tests__/lifecycle.test.ts @@ -0,0 +1,160 @@ +import { describe, it, expect } from 'vitest' +import { LifecycleBridge } from '../entry' +import { Agent, FunctionTool } from '@strands-agents/sdk' +import { MockMessageModel } from '$/fixtures/mock-message-model' + +describe('LifecycleBridge', () => { + async function runTextTurn(): Promise { + const bridge = new LifecycleBridge() + const model = new MockMessageModel() + model.addTurn({ type: 'textBlock', text: 'Hello' }) + const agent = new Agent({ model, plugins: [bridge], printer: false }) + await agent.invoke('hello') + return bridge + } + + describe('Plugin interface', () => { + it('has name property', () => { + const bridge = new LifecycleBridge() + expect(bridge.name).toBe('strands:lifecycle-bridge') + }) + + it('has initAgent method', () => { + const bridge = new LifecycleBridge() + expect(typeof bridge.initAgent).toBe('function') + }) + + it('has drain method', () => { + const bridge = new LifecycleBridge() + expect(typeof bridge.drain).toBe('function') + }) + }) + + describe('lifecycle events with simple text response', () => { + it('produces lifecycle events for a text-only agent turn', async () => { + const bridge = await runTextTurn() + const events = bridge.drain() + expect(events.length).toBeGreaterThan(0) + + const eventTypes = events.map((e) => e.val.eventType) + + expect(eventTypes).toContain('initialized') + expect(eventTypes).toContain('before-invocation') + expect(eventTypes).toContain('before-model-call') + expect(eventTypes).toContain('after-model-call') + expect(eventTypes).toContain('message-added') + expect(eventTypes).toContain('after-invocation') + + const initialized = events.find((e) => e.val.eventType === 'initialized') + expect(initialized).toStrictEqual({ + tag: 'lifecycle', + val: { eventType: 'initialized', toolUse: undefined, toolResult: undefined }, + }) + + const beforeInvocation = events.find((e) => e.val.eventType === 'before-invocation') + expect(beforeInvocation).toStrictEqual({ + tag: 'lifecycle', + val: { eventType: 'before-invocation', toolUse: undefined, toolResult: undefined }, + }) + + const beforeModelCall = events.find((e) => e.val.eventType === 'before-model-call') + expect(beforeModelCall).toStrictEqual({ + tag: 'lifecycle', + val: { eventType: 'before-model-call', toolUse: undefined, toolResult: undefined }, + }) + + const afterModelCall = events.find((e) => e.val.eventType === 'after-model-call') + expect(afterModelCall).toStrictEqual({ + tag: 'lifecycle', + val: { eventType: 'after-model-call', toolUse: undefined, toolResult: undefined }, + }) + + const messageAdded = events.find((e) => e.val.eventType === 'message-added') + expect(messageAdded).toStrictEqual({ + tag: 'lifecycle', + val: { eventType: 'message-added', toolUse: undefined, toolResult: undefined }, + }) + + const afterInvocation = events.find((e) => e.val.eventType === 'after-invocation') + expect(afterInvocation).toStrictEqual({ + tag: 'lifecycle', + val: { eventType: 'after-invocation', toolUse: undefined, toolResult: undefined }, + }) + }) + + it('non-tool events have undefined toolUse and toolResult', async () => { + const bridge = await runTextTurn() + const events = bridge.drain() + for (const event of events) { + expect(event.tag).toBe('lifecycle') + expect(event.val.toolUse).toBeUndefined() + expect(event.val.toolResult).toBeUndefined() + } + }) + }) + + describe('drain clears queue', () => { + it('first drain returns events, second drain returns empty array', async () => { + const bridge = await runTextTurn() + + const first = bridge.drain() + expect(first.length).toBeGreaterThan(0) + + const second = bridge.drain() + expect(second).toStrictEqual([]) + }) + }) + + describe('tool-related lifecycle events', () => { + it('produces before-tool-call and after-tool-call events with serialized data', async () => { + const bridge = new LifecycleBridge() + const model = new MockMessageModel() + + model.addTurn({ + type: 'toolUseBlock', + name: 'test_tool', + toolUseId: 'tu-1', + input: { query: 'test' }, + }) + model.addTurn({ type: 'textBlock', text: 'Done' }) + + const tool = new FunctionTool({ + name: 'test_tool', + description: 'A test tool', + inputSchema: { type: 'object', properties: { query: { type: 'string' } } }, + callback: () => [{ text: 'tool result' }], + }) + + const agent = new Agent({ + model, + plugins: [bridge], + tools: [tool], + printer: false, + }) + + await agent.invoke('use the tool') + + const events = bridge.drain() + + const beforeToolCall = events.find((e) => e.val.eventType === 'before-tool-call') + expect(beforeToolCall).toStrictEqual({ + tag: 'lifecycle', + val: { + eventType: 'before-tool-call', + toolUse: expect.any(String), + toolResult: undefined, + }, + }) + + const afterToolCall = events.find((e) => e.val.eventType === 'after-tool-call') + expect(afterToolCall).toStrictEqual({ + tag: 'lifecycle', + val: { + eventType: 'after-tool-call', + toolUse: expect.any(String), + toolResult: expect.any(String), + }, + }) + }) + }) +}) diff --git a/strands-wasm/__tests__/mapping.test.ts b/strands-wasm/__tests__/mapping.test.ts new file mode 100644 index 0000000000..1e8dc53150 --- /dev/null +++ b/strands-wasm/__tests__/mapping.test.ts @@ -0,0 +1,463 @@ +import { describe, it, expect } from 'vitest' +import { + mapUsage, + mapMetrics, + mapStopReasonTag, + mapStopReason, + mapEvent, + mapModelStreamEvent, + mapContentBlock, + mapToolStreamEvent, + parseInput, + parseStructuredOutputSchema, + parseSaveLatestStrategy, +} from '../entry' +import type { AgentStreamEvent, ModelStreamEvent, StopReason } from '@strands-agents/sdk' +import { ToolStreamEvent, ToolUseBlock, ToolResultBlock, TextBlock, ReasoningBlock } from '@strands-agents/sdk' + +describe('mapUsage', () => { + it.each([null, undefined])('returns undefined for %s input', (input) => { + expect(mapUsage(input)).toBeUndefined() + }) + + it('maps all fields correctly', () => { + expect(mapUsage({ inputTokens: 10, outputTokens: 20, totalTokens: 30 })).toStrictEqual({ + inputTokens: 10, + outputTokens: 20, + totalTokens: 30, + cacheReadInputTokens: undefined, + cacheWriteInputTokens: undefined, + }) + }) + + it('computes totalTokens when missing', () => { + expect(mapUsage({ inputTokens: 5, outputTokens: 3 })).toStrictEqual({ + inputTokens: 5, + outputTokens: 3, + totalTokens: 8, + cacheReadInputTokens: undefined, + cacheWriteInputTokens: undefined, + }) + }) + + it('includes cache fields when present', () => { + expect( + mapUsage({ + inputTokens: 10, + outputTokens: 20, + totalTokens: 30, + cacheReadInputTokens: 5, + cacheWriteInputTokens: 2, + }) + ).toStrictEqual({ + inputTokens: 10, + outputTokens: 20, + totalTokens: 30, + cacheReadInputTokens: 5, + cacheWriteInputTokens: 2, + }) + }) +}) + +describe('mapMetrics', () => { + it.each([null, undefined])('returns undefined for %s input', (input) => { + expect(mapMetrics(input)).toBeUndefined() + }) + + it('maps latencyMs', () => { + expect(mapMetrics({ latencyMs: 150 })).toStrictEqual({ latencyMs: 150 }) + }) + + it('defaults latencyMs to 0 when field is absent', () => { + expect(mapMetrics({})).toStrictEqual({ latencyMs: 0 }) + }) + + it('defaults latencyMs to 0 when field is explicitly undefined', () => { + expect(mapMetrics({ latencyMs: undefined })).toStrictEqual({ latencyMs: 0 }) + }) +}) + +describe('mapStopReasonTag', () => { + const mappings: [string, string][] = [ + ['endTurn', 'end-turn'], + ['toolUse', 'tool-use'], + ['maxTokens', 'max-tokens'], + ['contentFiltered', 'content-filtered'], + ['guardrailIntervened', 'guardrail-intervened'], + ['stopSequence', 'stop-sequence'], + ['modelContextWindowExceeded', 'model-context-window-exceeded'], + ['cancelled', 'cancelled'], + ] + + it.each(mappings)("maps '%s' to '%s'", (input, expected) => { + expect(mapStopReasonTag(input as StopReason)).toBe(expected) + }) + + it("maps unknown reason to 'error'", () => { + expect(mapStopReasonTag('unknownReason' as unknown as StopReason)).toBe('error') + }) + + it('covers every WIT StopReason variant except error', () => { + const witStopReasons = [ + 'end-turn', + 'tool-use', + 'max-tokens', + 'error', + 'content-filtered', + 'guardrail-intervened', + 'stop-sequence', + 'model-context-window-exceeded', + 'cancelled', + ] + const mappedOutputs = mappings.map(([, wit]) => wit) + const nonErrorVariants = witStopReasons.filter((r) => r !== 'error') + expect(mappedOutputs.sort()).toStrictEqual(nonErrorVariants.sort()) + }) +}) + +describe('mapStopReason', () => { + it('maps reason with no agent result', () => { + expect(mapStopReason('endTurn')).toStrictEqual({ + reason: 'end-turn', + usage: undefined, + metrics: undefined, + structuredOutput: undefined, + }) + }) + + it('maps reason with usage and metrics', () => { + expect( + mapStopReason('toolUse', { + usage: { inputTokens: 1, outputTokens: 2, totalTokens: 3 }, + metrics: { latencyMs: 100 }, + }) + ).toStrictEqual({ + reason: 'tool-use', + usage: { + inputTokens: 1, + outputTokens: 2, + totalTokens: 3, + cacheReadInputTokens: undefined, + cacheWriteInputTokens: undefined, + }, + metrics: { latencyMs: 100 }, + structuredOutput: undefined, + }) + }) + + it('serializes structured output as JSON string', () => { + expect( + mapStopReason('endTurn', { + structuredOutput: { name: 'Alice', age: 30 }, + }) + ).toStrictEqual({ + reason: 'end-turn', + usage: undefined, + metrics: undefined, + structuredOutput: '{"name":"Alice","age":30}', + }) + }) + + it('sets structuredOutput to undefined when not present', () => { + expect( + mapStopReason('endTurn', { usage: { inputTokens: 5, outputTokens: 10, totalTokens: 15 } }) + ).toStrictEqual({ + reason: 'end-turn', + usage: { + inputTokens: 5, + outputTokens: 10, + totalTokens: 15, + cacheReadInputTokens: undefined, + cacheWriteInputTokens: undefined, + }, + metrics: undefined, + structuredOutput: undefined, + }) + }) +}) + +describe('mapModelStreamEvent', () => { + it('maps text delta', () => { + const event: ModelStreamEvent = { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'hello' } } + expect(mapModelStreamEvent(event)).toStrictEqual({ tag: 'text-delta', val: 'hello' }) + }) + + it('returns null for non-text delta', () => { + const event: ModelStreamEvent = { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'toolUseInputDelta', input: '{}' }, + } + expect(mapModelStreamEvent(event)).toBeNull() + }) + + it('returns null for reasoningContentDelta', () => { + const event: ModelStreamEvent = { + type: 'modelContentBlockDeltaEvent', + delta: { type: 'reasoningContentDelta', text: 'thinking...' }, + } + expect(mapModelStreamEvent(event)).toBeNull() + }) + + it('maps modelContentBlockStartEvent with toolUseStart', () => { + const event: ModelStreamEvent = { + type: 'modelContentBlockStartEvent', + start: { type: 'toolUseStart', name: 'calc', toolUseId: 'tu-5' }, + } + expect(mapModelStreamEvent(event)).toStrictEqual({ + tag: 'tool-use', + val: { name: 'calc', toolUseId: 'tu-5', input: '{}' }, + }) + }) + + it('returns null for modelContentBlockStartEvent without start', () => { + const event: ModelStreamEvent = { type: 'modelContentBlockStartEvent' } + expect(mapModelStreamEvent(event)).toBeNull() + }) + + it('maps modelMetadataEvent', () => { + const event: ModelStreamEvent = { + type: 'modelMetadataEvent', + usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, + metrics: { latencyMs: 50 }, + } + expect(mapModelStreamEvent(event)).toStrictEqual({ + tag: 'metadata', + val: { + usage: { + inputTokens: 10, + outputTokens: 20, + totalTokens: 30, + cacheReadInputTokens: undefined, + cacheWriteInputTokens: undefined, + }, + metrics: { latencyMs: 50 }, + }, + }) + }) + + it('returns null for unrecognized model event', () => { + const event: ModelStreamEvent = { type: 'modelMessageStartEvent', role: 'assistant' } + expect(mapModelStreamEvent(event)).toBeNull() + }) + + it('maps modelMetadataEvent without usage or metrics', () => { + const event: ModelStreamEvent = { type: 'modelMetadataEvent' } + expect(mapModelStreamEvent(event)).toStrictEqual({ + tag: 'metadata', + val: { usage: undefined, metrics: undefined }, + }) + }) +}) + +describe('mapContentBlock', () => { + it('maps toolUseBlock', () => { + const block = new ToolUseBlock({ name: 'calc', toolUseId: 'tu-1', input: { x: 1 } }) + expect(mapContentBlock(block)).toStrictEqual({ + tag: 'tool-use', + val: { name: 'calc', toolUseId: 'tu-1', input: '{"x":1}' }, + }) + }) + + it('maps toolUseBlock with null input to empty object', () => { + const block = new ToolUseBlock({ name: 'calc', toolUseId: 'tu-1', input: null }) + expect(mapContentBlock(block)).toStrictEqual({ + tag: 'tool-use', + val: { name: 'calc', toolUseId: 'tu-1', input: '{}' }, + }) + }) + + it('maps toolResultBlock', () => { + const block = new ToolResultBlock({ + toolUseId: 'tu-1', + status: 'success', + content: [new TextBlock('ok')], + }) + expect(mapContentBlock(block)).toStrictEqual({ + tag: 'tool-result', + val: { toolUseId: 'tu-1', status: 'success', content: '[{"text":"ok"}]' }, + }) + }) + + it('returns null for textBlock', () => { + const block = new TextBlock('hello') + expect(mapContentBlock(block)).toBeNull() + }) + + it('returns null for reasoningBlock', () => { + const block = new ReasoningBlock({ text: '' }) + expect(mapContentBlock(block)).toBeNull() + }) +}) + +describe('mapToolStreamEvent', () => { + it('maps event with data', () => { + const event = new ToolStreamEvent({ data: { value: 42 } }) + expect(mapToolStreamEvent(event)).toStrictEqual({ + tag: 'tool-result', + val: { toolUseId: '', status: 'success', content: '{"data":{"value":42}}' }, + }) + }) + + it('maps event without data', () => { + const event = new ToolStreamEvent({}) + expect(mapToolStreamEvent(event)).toStrictEqual({ + tag: 'tool-result', + val: { toolUseId: '', status: 'success', content: '{"data":null}' }, + }) + }) + + it('maps event with string data', () => { + const event = new ToolStreamEvent({ data: 'processing step 1' }) + expect(mapToolStreamEvent(event)).toStrictEqual({ + tag: 'tool-result', + val: { toolUseId: '', status: 'success', content: '{"data":"processing step 1"}' }, + }) + }) +}) + +describe('mapEvent', () => { + describe('wrapper events', () => { + it('unwraps modelStreamUpdateEvent to mapModelStreamEvent', () => { + const event = { + type: 'modelStreamUpdateEvent', + event: { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'wrapped' } }, + } as unknown as AgentStreamEvent + expect(mapEvent(event)).toStrictEqual({ tag: 'text-delta', val: 'wrapped' }) + }) + + it('unwraps contentBlockEvent to mapContentBlock', () => { + const event = { + type: 'contentBlockEvent', + contentBlock: new ToolUseBlock({ name: 'tool1', toolUseId: 'tu-2', input: {} }), + } as unknown as AgentStreamEvent + expect(mapEvent(event)).toStrictEqual({ + tag: 'tool-use', + val: { name: 'tool1', toolUseId: 'tu-2', input: '{}' }, + }) + }) + + it('unwraps toolResultEvent to mapContentBlock', () => { + const event = { + type: 'toolResultEvent', + result: new ToolResultBlock({ toolUseId: 'tu-3', status: 'error', content: [] }), + } as unknown as AgentStreamEvent + expect(mapEvent(event)).toStrictEqual({ + tag: 'tool-result', + val: { toolUseId: 'tu-3', status: 'error', content: '[]' }, + }) + }) + + it('unwraps toolStreamUpdateEvent to mapToolStreamEvent', () => { + const event = { + type: 'toolStreamUpdateEvent', + event: new ToolStreamEvent({ data: { progress: 50 } }), + } as unknown as AgentStreamEvent + expect(mapEvent(event)).toStrictEqual({ + tag: 'tool-result', + val: { toolUseId: '', status: 'success', content: '{"data":{"progress":50}}' }, + }) + }) + }) + + describe('special events', () => { + it('maps interrupt event', () => { + const event = { interrupt: { reason: 'user' } } + expect(mapEvent(event as unknown as AgentStreamEvent)).toStrictEqual({ + tag: 'interrupt', + val: JSON.stringify(event), + }) + }) + + it('does not treat hook events with interrupt() method as interrupt stream events', () => { + const event = { type: 'beforeToolCallEvent', interrupt: () => {} } + expect(mapEvent(event as unknown as AgentStreamEvent)).toBeNull() + }) + }) + + describe('dropped events', () => { + it.each([ + 'beforeInvocationEvent', + 'afterInvocationEvent', + 'beforeModelCallEvent', + 'afterModelCallEvent', + 'beforeToolCallEvent', + 'afterToolCallEvent', + 'messageAddedEvent', + 'modelMessageEvent', + 'agentResultEvent', + 'beforeToolsEvent', + 'afterToolsEvent', + ])('returns null for %s', (type) => { + const event = { type } as unknown as AgentStreamEvent + expect(mapEvent(event)).toBeNull() + }) + }) +}) + +describe('parseInput', () => { + it('returns parsed array for JSON array input', () => { + expect(parseInput('[{"type":"text","text":"hi"}]')).toStrictEqual([{ type: 'text', text: 'hi' }]) + }) + + it('returns string for plain text', () => { + expect(parseInput('hello world')).toBe('hello world') + }) + + it('returns original string for JSON object (non-array)', () => { + expect(parseInput('{"key":"value"}')).toBe('{"key":"value"}') + }) + + it('returns empty string for empty input', () => { + expect(parseInput('')).toBe('') + }) + + it('returns original string for malformed JSON', () => { + expect(parseInput('{bad json')).toBe('{bad json') + }) +}) + +describe('parseSaveLatestStrategy', () => { + it.each(['message', 'invocation', 'trigger'] as const)("accepts valid strategy '%s'", (strategy) => { + expect(parseSaveLatestStrategy(strategy)).toBe(strategy) + }) + + it('returns undefined for unknown strategy', () => { + expect(parseSaveLatestStrategy('unknown')).toBeUndefined() + }) + + it('returns undefined for undefined input', () => { + expect(parseSaveLatestStrategy(undefined)).toBeUndefined() + }) + + it('returns undefined for empty string', () => { + expect(parseSaveLatestStrategy('')).toBeUndefined() + }) +}) + +describe('parseStructuredOutputSchema', () => { + it('returns undefined for undefined input', () => { + expect(parseStructuredOutputSchema(undefined)).toBeUndefined() + }) + + it('returns undefined for empty string', () => { + expect(parseStructuredOutputSchema('')).toBeUndefined() + }) + + it('parses a valid JSON schema into a Zod schema', () => { + const schema = parseStructuredOutputSchema( + JSON.stringify({ type: 'object', properties: { name: { type: 'string' } }, required: ['name'] }) + ) + expect(schema).toBeDefined() + expect(schema!.parse({ name: 'Alice' })).toStrictEqual({ name: 'Alice' }) + }) + + it('throws on invalid JSON', () => { + expect(() => parseStructuredOutputSchema('not valid json')).toThrow('Invalid structured output schema') + }) + + it('throws on invalid schema', () => { + expect(() => parseStructuredOutputSchema(JSON.stringify({ type: 'invalid_type_xyz' }))).toThrow( + 'Invalid structured output schema' + ) + }) +}) diff --git a/strands-wasm/__tests__/stream.test.ts b/strands-wasm/__tests__/stream.test.ts new file mode 100644 index 0000000000..7e2108df57 --- /dev/null +++ b/strands-wasm/__tests__/stream.test.ts @@ -0,0 +1,183 @@ +import { describe, it, expect, vi } from 'vitest' +import { api, LifecycleBridge } from '../entry' +import { Agent } from '@strands-agents/sdk' +import { MockMessageModel } from '$/fixtures/mock-message-model' + +const ResponseStream = api.ResponseStream + +function createAgent(): Agent { + return new Agent({ model: new MockMessageModel(), printer: false }) +} + +const beforeModelCallEvent = { + tag: 'lifecycle', + val: { eventType: 'before-model-call', toolUse: undefined, toolResult: undefined }, +} + +const afterInvocationEvent = { + tag: 'lifecycle', + val: { eventType: 'after-invocation', toolUse: undefined, toolResult: undefined }, +} + +function setupStream( + genFn: () => AsyncGenerator, + preQueued?: any[] +): { stream: InstanceType; bridge: LifecycleBridge } { + const agent = createAgent() + const bridge = new LifecycleBridge() + if (preQueued) bridge.queue.push(...preQueued) + vi.spyOn(agent, 'stream').mockReturnValue(genFn()) + return { stream: new ResponseStream(agent, 'test', bridge), bridge } +} + +describe('ResponseStreamImpl.readNext', () => { + describe('mid-stream batch', () => { + it('returns lifecycle events interleaved with mapped event', async () => { + const { stream } = setupStream( + async function* () { + yield { + type: 'modelStreamUpdateEvent', + event: { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'hello' } }, + } + }, + [beforeModelCallEvent] + ) + const batch = await stream.readNext() + + expect(batch).toStrictEqual([beforeModelCallEvent, { tag: 'text-delta', val: 'hello' }]) + }) + + it('returns empty array when no events to report', async () => { + const { stream } = setupStream(async function* () { + yield { type: 'unknownEvent' } + }) + const batch = await stream.readNext() + + expect(batch).toStrictEqual([]) + }) + }) + + describe('final batch', () => { + it('returns stop event when generator completes with result', async () => { + const { stream } = setupStream(async function* () { + return { + stopReason: 'endTurn', + metrics: { + accumulatedUsage: { inputTokens: 1, outputTokens: 2, totalTokens: 3 }, + accumulatedMetrics: { latencyMs: 100 }, + }, + } + }) + const batch = await stream.readNext() + + expect(batch).toStrictEqual([ + { + tag: 'stop', + val: { + reason: 'end-turn', + usage: { + inputTokens: 1, + outputTokens: 2, + totalTokens: 3, + cacheReadInputTokens: undefined, + cacheWriteInputTokens: undefined, + }, + metrics: { latencyMs: 100 }, + structuredOutput: undefined, + }, + }, + ]) + }) + + it('returns lifecycle events when generator completes with no result but has pending lifecycle events', async () => { + const { stream } = setupStream( + async function* () { + return undefined + }, + [afterInvocationEvent] + ) + const batch = await stream.readNext() + + expect(batch).toStrictEqual([afterInvocationEvent]) + }) + + it('returns undefined when generator completes with no result and no lifecycle events', async () => { + const { stream } = setupStream(async function* () { + return undefined + }) + const batch = await stream.readNext() + + expect(batch).toBeUndefined() + }) + }) + + describe('error batch', () => { + it('returns lifecycle events with error event when generator throws', async () => { + const { stream } = setupStream( + async function* () { + throw new Error('model failed') + }, + [beforeModelCallEvent] + ) + const batch = await stream.readNext() + + expect(batch).toStrictEqual([beforeModelCallEvent, { tag: 'error', val: 'model failed' }]) + }) + }) + + describe('done state', () => { + it('returns undefined after stream is done', async () => { + const { stream } = setupStream(async function* () { + return { stopReason: 'endTurn' } + }) + await stream.readNext() + const batch = await stream.readNext() + + expect(batch).toBeUndefined() + }) + }) + + describe('cancel', () => { + it('cancel sets done state', async () => { + const { stream } = setupStream(async function* () { + yield { + type: 'modelStreamUpdateEvent', + event: { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'hello' } }, + } + yield { + type: 'modelStreamUpdateEvent', + event: { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'world' } }, + } + }) + stream.cancel() + const batch = await stream.readNext() + + expect(batch).toBeUndefined() + }) + + it('cancel restores default tools and model', async () => { + const agent = createAgent() + const bridge = new LifecycleBridge() + const defaultTools = [{ name: 'default_tool' }] as any[] + const clearSpy = vi.spyOn(agent.toolRegistry, 'clear') + const addSpy = vi.spyOn(agent.toolRegistry, 'add') + + vi.spyOn(agent, 'stream').mockReturnValue( + (async function* () { + yield { + type: 'modelStreamUpdateEvent', + event: { type: 'modelContentBlockDeltaEvent', delta: { type: 'textDelta', text: 'hello' } }, + } + })() + ) + + const stream = new ResponseStream(agent, 'test', bridge, defaultTools) + stream.cancel() + + const batch = await stream.readNext() + expect(batch).toBeUndefined() + expect(clearSpy).toHaveBeenCalled() + expect(addSpy).toHaveBeenCalledWith(defaultTools) + }) + }) +}) diff --git a/strands-wasm/__tests__/tool-bridge.test.ts b/strands-wasm/__tests__/tool-bridge.test.ts new file mode 100644 index 0000000000..77b9f9d4c2 --- /dev/null +++ b/strands-wasm/__tests__/tool-bridge.test.ts @@ -0,0 +1,105 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { createTools } from '../entry' +import { callTool } from 'strands:agent/tool-provider' + +const emptyToolContext = { toolUse: { toolUseId: '' } } as any + +describe('createTools', () => { + describe('spec handling', () => { + it('returns undefined for undefined specs', () => { + expect(createTools(undefined)).toStrictEqual(undefined) + }) + + it('returns undefined for empty array', () => { + expect(createTools([])).toStrictEqual(undefined) + }) + + it('creates FunctionTool with correct properties', () => { + const specs = [ + { + name: 'calculator', + description: 'Does math', + inputSchema: '{"type":"object","properties":{"expression":{"type":"string"}}}', + }, + ] + const tools = createTools(specs) + expect(tools).toHaveLength(1) + expect(tools![0].name).toBe('calculator') + expect(tools![0].description).toBe('Does math') + }) + + it('parses inputSchema from JSON string', () => { + const tools = createTools([{ name: 'x', description: 'y', inputSchema: '{"type":"object"}' }]) + expect(tools![0].toolSpec.inputSchema).toStrictEqual({ type: 'object' }) + }) + }) + + describe('callback behavior', () => { + const makeTools = (name = 'calc') => createTools([{ name, description: 'math', inputSchema: '{"type":"object"}' }])! + + beforeEach(() => { + vi.mocked(callTool).mockReset() + }) + + it('calls callTool with correct args', async () => { + const tools = makeTools() + vi.mocked(callTool).mockReturnValue('{"result": 42}') + const toolContext = { toolUse: { toolUseId: 'tu-123' } } + await tools[0].invoke({ expression: '1+1' }, toolContext) + expect(callTool).toHaveBeenCalledWith({ + name: 'calc', + input: '{"expression":"1+1"}', + toolUseId: 'tu-123', + }) + }) + + it('strips {status, content} wrapper from host result', async () => { + const tools = makeTools() + vi.mocked(callTool).mockReturnValue(JSON.stringify({ status: 'success', content: [{ text: 'hello' }] })) + const result = await tools[0].invoke({}, emptyToolContext) + expect(result).toStrictEqual([{ text: 'hello' }]) + }) + + it('handles WIT Result ok variant', async () => { + const tools = makeTools() + vi.mocked(callTool).mockReturnValue({ + tag: 'ok', + val: JSON.stringify({ status: 'success', content: [{ text: 'ok' }] }), + }) + const result = await tools[0].invoke({}, emptyToolContext) + expect(result).toStrictEqual([{ text: 'ok' }]) + }) + + it('throws on WIT Result err variant', async () => { + const tools = makeTools() + vi.mocked(callTool).mockReturnValue({ tag: 'err', val: 'tool failed' }) + await expect(tools[0].invoke({}, emptyToolContext)).rejects.toThrow('tool failed') + }) + + it('propagates host exceptions', async () => { + const tools = makeTools() + vi.mocked(callTool).mockImplementation(() => { + throw new Error('host crashed') + }) + await expect(tools[0].invoke({}, emptyToolContext)).rejects.toThrow('host crashed') + }) + + it('uses empty string for toolUseId when context is missing', async () => { + const tools = makeTools() + vi.mocked(callTool).mockReturnValue('{"value": 1}') + await tools[0].invoke({ x: 1 }, emptyToolContext) + expect(callTool).toHaveBeenCalledWith({ + name: 'calc', + input: '{"x":1}', + toolUseId: '', + }) + }) + + it('returns parsed result directly when not a {status, content} wrapper', async () => { + const tools = makeTools() + vi.mocked(callTool).mockReturnValue('{"custom": "data"}') + const result = await tools[0].invoke({}, emptyToolContext) + expect(result).toStrictEqual({ custom: 'data' }) + }) + }) +}) diff --git a/strands-wasm/build.js b/strands-wasm/build.js new file mode 100644 index 0000000000..f2690d725f --- /dev/null +++ b/strands-wasm/build.js @@ -0,0 +1,59 @@ +/** + * Build script for the strands-agent WASM component. + * + * Steps: + * 1. esbuild – bundle entry.ts + full SDK into a single ESM file + * 2. componentize – compile the bundle into a WASM component + * targeting the `agent` world (exports strands:agent/api directly) + * + * Prerequisites: + * - npm install at the workspace root + * - @strands-agents/sdk must be built first (npm run build:sdk) + * + * Key build flags: + * --platform=browser AWS SDK uses fetch instead of node:http + * --define:import.meta.vitest=undefined + * StarlingMonkey throws on unknown import.meta + * properties; the SDK uses import.meta.vitest + * for in-source tests that must be eliminated. + */ + +import { mkdirSync, readFileSync, writeFileSync } from 'node:fs'; +import { resolve } from 'node:path'; +import { build } from 'esbuild'; +import { componentize } from '@chaynabors/componentize-js'; + +mkdirSync('dist', { recursive: true }); + +const witDir = resolve(import.meta.dirname, '..', 'wit'); + +// 1. Bundle: resolve all imports into a single ESM file. +await build({ + entryPoints: ['entry.ts'], + bundle: true, + format: 'esm', + platform: 'browser', + target: 'es2022', + define: { 'import.meta.vitest': 'undefined' }, + external: [ + '@modelcontextprotocol/sdk/client/sse.js', + '@modelcontextprotocol/sdk/client/stdio.js', + 'child_process', + 'fs', + 'node:*', + 'path', + 'strands:*', + ], + outfile: 'dist/bundle.js', + logLevel: 'info', +}); + +// 2. Componentize: compile the bundle into a WASM component. +const source = readFileSync('dist/bundle.js', 'utf-8'); +const { component } = await componentize(source, { + witPath: witDir, + worldName: 'agent', +}); +writeFileSync('dist/strands-agent.wasm', component); + +console.log('\n✓ strands-ts-wasm/dist/strands-agent.wasm'); diff --git a/strands-wasm/docs/feature-development.md b/strands-wasm/docs/feature-development.md new file mode 100644 index 0000000000..468df3cac9 --- /dev/null +++ b/strands-wasm/docs/feature-development.md @@ -0,0 +1,261 @@ +# WASM Feature Development Guide + +Follow this guide when developing new features or modifying existing implementations across the WASM bridge. Changes that cross the WASM boundary touch multiple files across layers. + +For general development standards (conventional commits, test coverage, formatting, linting, TSDoc), see [CONTRIBUTING.md](../../CONTRIBUTING.md). For the SDK's compatibility policy on non-breaking changes (union type extensions, getter/setter conversions), see [COMPATIBILITY.MD](../../COMPATIBILITY.MD). + +## File ownership + +Know which file owns which concern. Read the relevant files before modifying them. + +| File | Owns | When to modify | +|---|---|---| +| `wit/agent.wit` | Boundary types and contract between guest and host | Adding new config fields, new WIT records, new resource methods, or new import/export interfaces | +| `strands-wasm/entry.ts` | Config deserialization, TS SDK instantiation, event mapping | Changing how config is read from WIT and passed to TS SDK constructors, adding new `createXxx()` functions, modifying stream event mapping | +| `strands-py-wasm/strands/_wasm_host.py` | Config serialization (Python → WIT records), WASM runtime management (`WasmAgent`), raw wasmtime `Variant` → Python `StreamEvent` dataclass conversion | Adding `_build_xxx()` serialization functions, modifying `WasmAgent` methods, changing how raw wasmtime variants are converted to `StreamEvent` dataclasses | +| `strands-py-wasm/strands/agent/__init__.py` | Python user-facing API, config extraction from Python objects to dicts | Adding/modifying constructor parameters, extracting config from Python class instances | +| `strands-py-wasm/strands/_conversions.py` | `StreamEvent` dataclass → Python SDK dict format, TS SDK message format → Python SDK message format | Modifying how `StreamEvent` dataclasses are converted to dicts (`event_to_dict`), how TS messages are converted to Python format (`convert_message`), or how lifecycle events are mapped to hook events (`lifecycle_event_from_wit`) | + +### Files you must not edit manually + +| File | Why | +|---|---| +| `strands-py-wasm/strands/_generated/types.py` | Auto-generated from `wit/agent.wit` by `strands-py-wasm/scripts/generate_types.py`. Regenerate with `npm run dev -- generate`. | +| `strands-wasm/generated/` | Auto-generated WIT type bindings. Regenerate with `npm run dev -- generate`. | +| `strands-wasm/build.js` | Build pipeline script. Rarely needs changes unless adding a new esbuild plugin or changing the componentize step. | +| `strands-wasm/patches/getChunkedStream.js` | WASI buffer reuse workaround. Only modify if fixing the specific componentize-js buffering bug it addresses. | + +## Naming conventions across layers + +Each layer uses a different case convention. Use the correct case for the layer you are writing in. + +| Layer | Convention | Example | +|---|---|---| +| WIT (`wit/agent.wit`) | `kebab-case` | `window-size`, `should-truncate-results` | +| TS (`strands-wasm/entry.ts`) | `camelCase` | `windowSize`, `shouldTruncateResults` | +| Python (`strands-py-wasm/`) | `snake_case` | `window_size`, `should_truncate_results` | + +componentize-js translates WIT `kebab-case` to JS `camelCase` automatically. When `entry.ts` reads `cmConfig.windowSize`, it is accessing the WIT field `window-size`. Do not convert manually in `entry.ts`. + +wasmtime-py does **not** translate automatically. Use `kebab-case` keys directly when building or reading WIT records in `_wasm_host.py`: + +```python +_rec(**{"window-size": 40, "should-truncate-results": True}) +``` + +Reading a WIT record returned from the guest: + +```python +getattr(rec, "window-size") +``` + +## Decision: where does the feature run? + +Answer these questions before writing any code. Read the TS SDK implementation of the feature first. + +**Does the feature need to execute Python user code at runtime?** (e.g., calling a Python function when the model requests a tool) +- Yes → Needs a WIT **import** interface. The guest calls back to the host. See `tool-provider` in `wit/agent.wit` for the pattern. +- No → Feature runs entirely in the WASM guest. + +**Is the feature configured once at construction, or invoked at runtime?** +- Construction → **Config holder pattern.** Python class stores config, serialized through WIT, TS instantiates the real implementation. See conversation manager for the pattern. +- Runtime → Needs WIT **export** methods on the `agent` or `response-stream` resource. See `get-messages`, `set-messages` for the pattern. + +**Is the feature a Plugin in the TS SDK?** +- Yes → Pass via the appropriate Agent constructor field (`conversationManager`, `plugins`, `sessionManager`). The TS `PluginRegistry` calls `initAgent()` automatically. Do **not** register the Python config holder as a hook provider. +- No → Wire directly in the `AgentImpl` constructor in `entry.ts`. + +## Workflow: adding a new feature + +Follow these steps in order. Each step includes a verification checkpoint. + +### Step 1: Read the TS SDK implementation + +Read the TS source files for the feature. Identify: +- What config does the constructor accept? (types, defaults, required vs optional) +- What runtime behavior does it have? (hooks registered, methods called, events emitted) +- Is it a Plugin? (extends `Plugin`, has `initAgent()`) +- What does the public API look like for TS users? + +Do not proceed until you can answer all four questions from the code you read. + +### Step 2: Update the WIT contract + +Read `wit/agent.wit` in full before modifying it. Add the new record(s) and/or fields. + +**Pattern: flat record with string discriminator.** When a feature has multiple strategies (like conversation manager), use a flat record with a `strategy: string` field rather than a WIT `variant`. This works around a wasmtime-py limitation where `option` types are not properly supported. + +```wit +record my-feature-config { + strategy: string, + field-a: s32, + field-b: option, +} +``` + +**Pattern: adding a field to `agent-config`.** Add the new config as `option` to the `agent-config` record. + +**Extending existing WIT variants.** Adding a new variant case to an existing WIT `variant` type (e.g., a new model provider to `model-config`, or a new tag to `stream-event`) is a non-breaking change per the project's [compatibility policy](../../COMPATIBILITY.MD). Existing host code that pattern-matches on known tags will ignore the new tag. Do not add backwards-compatibility shims for new variant cases. + +**Regenerate types** after updating `wit/agent.wit`: run `npm run dev -- generate`. This updates `strands-wasm/generated/` and `strands-py-wasm/strands/_generated/types.py` to match the new contract. + +**Verification:** Run `npm run dev -- validate wit`. Fix any compile errors in downstream layers before proceeding. + +### Step 3: Update `strands-wasm/entry.ts` + +Read `entry.ts` in full before modifying it. Add imports for the TS SDK classes you will instantiate to the top-level import block. All imports must be at the top of the file. Then add a `createXxx()` function that: +1. Reads the config from `(config as any).myField` (the `as any` cast is necessary because WIT-generated `AgentConfig` types may not include new fields until regenerated) +2. Returns `undefined` when no config is provided, letting the TS `Agent` constructor apply its own default +3. Instantiates the real TS SDK class with the config values +4. Returns the proper TS SDK type (not `any`) + +```typescript +function createMyFeature(config: AgentConfig): MyFeatureClass | undefined { + const cfg = (config as any).myField + if (!cfg) { + return undefined + } + return new MyFeatureClass({ + fieldA: cfg.fieldA, + fieldB: cfg.fieldB ?? undefined, + }) +} +``` + +Use `?? undefined` for WIT `option` fields. The componentize-js runtime passes `undefined` for absent options, but `null` can appear in some edge cases. The `??` operator normalizes both to `undefined`. + +Pass the result to the `Agent` constructor in `AgentImpl`. + +**Do not duplicate TS SDK defaults.** If the TS SDK constructor defaults `fieldA` to `40`, do not also hardcode `40` in `entry.ts`. Return `undefined` and let the TS SDK apply its own default. + +**Verification:** Run `npm run dev -- validate wasm`. Ensure the WASM component builds. + +### Step 4: Update the Python host + +Read each file in full before modifying it. + +**`strands-py-wasm/strands/_wasm_host.py`** — Add a `_build_xxx_variant()` function that serializes a Python config dict to a WIT record. Add the parameter to `_build_agent_config()` and `WasmAgent.__init__()`. + +```python +def _build_my_feature_variant(config: dict[str, typing.Any] | None) -> Record | None: + if config is None: + return None + return _rec( + strategy=config["type"], + **{ + "field-a": config.get("field_a"), + "field-b": config.get("field_b"), + }, + ) +``` + +Pass through values the user provided. Do not insert defaults here — let the TS SDK apply its own defaults for absent fields. + +**`strands-py-wasm/strands/agent/__init__.py`** — Add the parameter to `Agent.__init__()` with a proper type hint. Add config extraction logic that inspects the instance type and builds a config dict. Always include a `dict` passthrough and an `else` warning for unknown types. + +```python +feat_config: dict[str, Any] | None = None +if my_feature is not None: + from strands.agent.my_feature import MyFeatureA as _A, MyFeatureB as _B + + if isinstance(my_feature, _A): + feat_config = {"type": "strategy-a", "field_a": my_feature.field_a} + elif isinstance(my_feature, _B): + feat_config = {"type": "strategy-b", "field_b": my_feature.field_b} + elif isinstance(my_feature, dict): + feat_config = my_feature + else: + log.warning("unknown my_feature type: %s, ignoring", type(my_feature).__name__) +``` + +**Feature module** (e.g., `strands-py-wasm/strands/agent/my_feature/`) — Create config holder classes that store user-provided config and nothing else. They extend `HookProvider` for type compatibility with the `Agent` constructor, but must **not** register any hooks. Hook registration happens in the TS SDK's `initAgent()` inside the WASM guest. + +```python +class MyFeatureManager(HookProvider): + def __init__(self, field_a: int = 40, field_b: str | None = None) -> None: + self.field_a = field_a + self.field_b = field_b +``` + +**Verification:** Run `python -m pytest strands-py-wasm/tests_unit/` to validate serialization. + +### Step 5: Write tests + +The project requires 80% test coverage (see [CONTRIBUTING.md](../../CONTRIBUTING.md)). + +**Unit tests** (`strands-py-wasm/tests_unit/`): Test the serialization boundary. Verify that config holder classes store the right values, that `_build_xxx_variant()` produces correct WIT records, and that edge cases (missing fields, invalid values) are handled. + +**Integration tests** (`strands-py-wasm/tests_integ/`): Test end-to-end behavior. Create an agent with the feature configured, invoke it, and verify observable behavior. Do **not** test by calling internal methods on config holder classes — the implementation runs in the TS guest, so test through the agent's public API. + +### Step 6: Document the change + +**`strands-wasm/docs/python-api-changes.md`** — For each Python API change, document: +1. The TS SDK design (with code) +2. The WASM bridge implementation +3. The Python API (before/after code snippets) +4. How the functionality is preserved if the API surface differs from the standalone Python SDK + +**`AGENTS.md`** — If the change adds new directories, files, or significantly restructures existing modules, update the directory structure section in [AGENTS.md](../../AGENTS.md). + +## Workflow: modifying an existing bridged feature + +Modifications (adding a parameter, fixing a bug, changing a default) are more common than new features. Discover the full data flow before changing anything. + +### Step 1: Trace the data flow + +Grep for the feature across all layers to find every file involved: + +```bash +grep -rn 'feature_name\|featureName\|feature-name' wit/ strands-wasm/entry.ts strands-py-wasm/strands/ +``` + +Read every file that appears in the results. Trace the full path: Python construction → WIT serialization → TS instantiation. Identify every function, record, and field involved before making changes. + +### Step 2: Identify the change scope + +Determine which layers your change affects: + +- **Adding a config parameter**: All layers change (WIT record, `entry.ts` reader, `_wasm_host.py` serializer, `agent/__init__.py` extractor, config holder class, tests). +- **Changing a default value**: Usually only the layer that owns the default. If the WASM bridge delegates to the TS SDK default (returns `undefined`), changes to the TS SDK default propagate automatically. If the bridge hardcodes a default, it must be updated. +- **Fixing a serialization bug**: Usually `_wasm_host.py` (Python → WIT) or `_conversions.py` (WIT → Python), plus tests. +- **Fixing a type mismatch**: May involve multiple layers. Trace the type from Python through WIT to TS to find where the mismatch originates. + +### Step 3: Make changes in dependency order + +Changes cascade through the pipeline. Make changes in this order so each layer compiles against the updated layer above it: + +1. `wit/agent.wit` (if the contract changes) +2. Regenerate types: `npm run dev -- generate` +3. `strands-wasm/entry.ts` +4. `strands-py-wasm/strands/_wasm_host.py` +5. `strands-py-wasm/strands/agent/__init__.py` and feature modules +6. Tests + +### Step 4: Verify at each layer + +After modifying each layer, run the appropriate validation: + +| Layer changed | Validation command | +|---|---| +| `wit/agent.wit` | `npm run dev -- validate wit` | +| `strands-wasm/entry.ts` | `npm run dev -- validate wasm` | +| `strands-py-wasm/` | `python -m pytest strands-py-wasm/tests_unit/` | +| All layers | `npm run dev -- ci` | + +## Common pitfalls + +**Read before you write.** Always read a file before modifying it. Do not assume what a function signature, WIT record, or config dict looks like. The codebase changes across PRs. Stale assumptions cause incorrect edits. + +**Do not duplicate TS SDK defaults.** If the TS SDK defaults `windowSize` to `40`, do not hardcode `40` in `entry.ts` or `_wasm_host.py`. Return `undefined` and let the TS SDK own its defaults. Hardcoded values silently diverge when the TS SDK changes. + +**Do not register hooks in Python config holders.** Config holder classes extend `HookProvider` for type compatibility only. All hook registration happens in the TS SDK's `initAgent()` inside the WASM guest. Registering hooks on the Python side creates duplicate behavior. + +**Do not edit generated files.** `strands-py-wasm/strands/_generated/types.py` and `strands-wasm/generated/` are auto-generated. Edits are overwritten on the next `npm run dev -- generate`. + +**Separate formatting from feature changes.** Keep formatting (Prettier, ruff) in separate commits or PRs. Mixed diffs obscure functional changes. + +**Update `_conversions.py` for return-path changes.** Data returning from the WASM guest (messages, stream events) passes through `_conversions.py`, not `_wasm_host.py`. If the TS SDK changes message format or event types, update `_conversions.py`. + +**Keep serialization types explicit.** If a Python constructor accepts `dict[str, Any]` but the serialized form is `str` (JSON), store the user-provided type on the class and serialize in a dedicated method at the bridge boundary. Do not silently convert types in the constructor. + +**Set all WIT record fields.** When using the flat record pattern with a strategy discriminator, every field must be present in every record instance, even if unused. wasmtime-py requires all fields of a record to be set. Use zero values or `None` for unused fields. diff --git a/strands-wasm/docs/python-api-changes.md b/strands-wasm/docs/python-api-changes.md new file mode 100644 index 0000000000..2517ec67fb --- /dev/null +++ b/strands-wasm/docs/python-api-changes.md @@ -0,0 +1,345 @@ +# Python API Changes + +Tracks all Python SDK API changes that result from the WASM bridge architecture. Each feature section documents the TypeScript SDK design, the WASM bridge implementation, and the resulting Python API change with code evidence. + +--- + +## Conversation Manager + +The Python conversation manager classes are config holders. The actual implementation runs inside the TypeScript SDK WASM guest. + +### 1. Conversation manager is not accessible after construction + +**TS design:** The agent stores the conversation manager as a private field. + +```typescript +// strands-ts/src/agent/agent.ts:191 +private readonly _conversationManager: ConversationManager +``` + +There is no public getter. Users configure it at construction and never access it again. + +**WASM bridge:** The config is serialized through the WIT contract during agent construction. No handle to the TS conversation manager instance is retained on the Python side. + +**Python API change:** + +```python +# Standalone Python SDK (1.x) — worked +agent = Agent(conversation_manager=SlidingWindowConversationManager()) +agent.conversation_manager # accessible + +# WASM bridged Python SDK (2.x) — not available +agent = Agent(conversation_manager=SlidingWindowConversationManager()) +agent.conversation_manager # AttributeError +``` + +Not needed. The conversation manager operates automatically via hooks registered during `initAgent()`. + +### 2. No manual `reduce_context()` or `apply_management()` + +**TS design:** Context reduction is hook driven. The base class registers an `AfterModelCallEvent` callback that catches overflow errors and calls `reduce()` automatically. + +```typescript +// strands-ts/src/conversation-manager/conversation-manager.ts:100-108 +initAgent(agent: LocalAgent): void { + agent.addHook(AfterModelCallEvent, async (event) => { + if (event.error instanceof ContextWindowOverflowError) { + if (await this.reduce({ agent: event.agent, model: event.model, error: event.error })) { + event.retry = true + } + } + }) + } +``` + +`SlidingWindowConversationManager` adds proactive trimming via a second hook: + +```typescript +// strands-ts/src/conversation-manager/sliding-window-conversation-manager.ts:72-78 +public override initAgent(agent: LocalAgent): void { + super.initAgent(agent) + agent.addHook(AfterInvocationEvent, (event) => { + this._applyManagement(event.agent.messages) + }) + } +``` + +There are no public methods to trigger these manually. The hooks system is the invocation mechanism. + +**WASM bridge:** `createConversationManager()` in `strands-wasm/entry.ts` instantiates the real TS class. The TS `Agent` constructor adds it to `PluginRegistry`, which calls `initAgent()`. Both hooks are registered inside the WASM guest. + +**Python API change:** + +```python +# Standalone Python SDK (1.x) — worked +cm = SlidingWindowConversationManager() +agent = Agent(conversation_manager=cm) +cm.reduce_context(agent) # manually trigger reduction +cm.apply_management(agent) # manually trigger window trimming + +# WASM bridged Python SDK (2.x) — not available +cm = SlidingWindowConversationManager() +agent = Agent(conversation_manager=cm) +cm.reduce_context(agent) # AttributeError — no such method +cm.apply_management(agent) # AttributeError — no such method +``` + +Not needed. Overflow recovery fires automatically on `ContextWindowOverflowError`. Proactive trimming fires automatically after every invocation when messages exceed `windowSize`. + +### 3. Summarization accepts a model config, not an agent + +**TS design:** `SummarizingConversationManager` accepts a `model`, not an agent. Summarization calls the model directly. + +```typescript +// strands-ts/src/conversation-manager/summarizing-conversation-manager.ts:46-51 +export type SummarizingConversationManagerConfig = { + model?: Model + summaryRatio?: number + preserveRecentMessages?: number + summarizationSystemPrompt?: string +} +``` + +```typescript +// strands-ts/src/conversation-manager/summarizing-conversation-manager.ts:157-160 +private async _generateSummary(messagesToSummarize: Message[], model: Model): Promise { + // ... + const stream = model.streamAggregated(summarizationMessages, { + systemPrompt: this._summarizationSystemPrompt, + }) +``` + +**WASM bridge:** The Python user provides a model config dict. `createConversationManager()` in `strands-wasm/entry.ts` parses the JSON and calls `createModel()` to instantiate a TS model: + +```typescript +// strands-wasm/entry.ts:427-430 +if (cmConfig.summarizationModelConfig) { + const parsed = JSON.parse(cmConfig.summarizationModelConfig) + summaryModel = createModel(parsed) +} +``` + +**Python API change:** + +```python +# Standalone Python SDK (1.x) — accepted a full Agent instance +summarizer = Agent(model=some_model, system_prompt="Summarize.") +agent = Agent(conversation_manager=SummarizingConversationManager( + summarization_agent=summarizer, +)) + +# WASM bridged Python SDK (2.x) — accepts a model config dict +agent = Agent(conversation_manager=SummarizingConversationManager( + summarization_model_config={ + "provider": "bedrock", + "model_id": "us.anthropic.claude-3-haiku-20240307-v1:0", + }, +)) +``` + +The WASM boundary cannot serialize a live `Agent` instance. The model config dict is instantiated as a TS model inside the guest, which matches the TS SDK's design of calling the model directly rather than re-entering the agent loop. + +### 4. `per_turn` parameter not supported + +**TS design:** `SlidingWindowConversationManager` does not implement `per_turn`. Proactive trimming runs unconditionally after every invocation via the `AfterInvocationEvent` hook when messages exceed `windowSize`. + +**Python API change:** + +```python +# Standalone Python SDK (1.x) — worked +agent = Agent(conversation_manager=SlidingWindowConversationManager(per_turn=3)) + +# WASM bridged Python SDK (2.x) — not supported +agent = Agent(conversation_manager=SlidingWindowConversationManager(per_turn=3)) +# per_turn is silently ignored (caught by **_kwargs) +``` + +The TS SDK trims after every invocation when the window is exceeded, which is equivalent to `per_turn=True`. + +### 5. Session state methods not available + +**TS design:** The TS SDK has its own session management system. Conversation manager state persistence (`_summary_message`, `removed_message_count`) is not part of the `ConversationManager` interface. + +**Python API change:** + +```python +# Standalone Python SDK (1.x) — worked +state = cm.get_state() +cm.restore_from_session(state) +cm.removed_message_count + +# WASM bridged Python SDK (2.x) — not available +``` + +--- + +## WIT Contract + +The `conversation-manager-config` uses a flat record with a string `strategy` discriminator (`"none"`, `"sliding-window"`, `"summarizing"`) rather than a WIT variant. This works around a wasmtime-py limitation where `option` types are not properly supported. + +```wit +record conversation-manager-config { + strategy: string, + window-size: s32, + should-truncate-results: bool, + summary-ratio: option, + preserve-recent-messages: option, + summarization-system-prompt: option, + summarization-model-config: option, +} +``` + +Fields irrelevant to the selected strategy are set to zero values or `None`. + +--- + +## Python Config Reference + +### `NullConversationManager` + +No parameters. Disables conversation management. Overflow errors propagate uncaught. + +### `SlidingWindowConversationManager` + +| Parameter | Type | Default | TS equivalent | +|---|---|---|---| +| `window_size` | `int` | `40` | `windowSize` | +| `should_truncate_results` | `bool` | `True` | `shouldTruncateResults` | + +### `SummarizingConversationManager` + +| Parameter | Type | Default | TS equivalent | +|---|---|---|---| +| `summary_ratio` | `float` | `0.3` | `summaryRatio` (clamped 0.1 to 0.8) | +| `preserve_recent_messages` | `int` | `10` | `preserveRecentMessages` | +| `summarization_system_prompt` | `str \| None` | `None` | `summarizationSystemPrompt` | +| `summarization_model_config` | `dict \| None` | `None` | Serialized to JSON, parsed by TS guest, passed to `createModel()` to produce a `Model` for `config.model` | + +Model config dict format: + +```python +{ + "provider": "bedrock", # "bedrock", "anthropic", "openai", or "gemini" + "model_id": "us.anthropic.claude-3-haiku-20240307-v1:0", + "region": "us-west-2", # bedrock only + "api_key": "...", # anthropic, openai, gemini only +} +``` + +--- + +## Structured Output + +### Overview + +Structured output validation runs inside the TypeScript SDK WASM guest using `StructuredOutputTool`. Python sends a flattened JSON schema through the WIT contract. The TS agent loop registers the tool, handles force-retry when the model doesn't call it, validates with Zod, and returns the validated JSON on the stop event. Python instantiates the Pydantic model from the validated JSON. + +### 1. Parameter name changed + +**TS design:** The TS Agent accepts `structuredOutputSchema` (a Zod schema) on `AgentConfig` and `InvokeOptions`. + +```typescript +// strands-ts/src/types/agent.ts:165 +structuredOutputSchema?: z.ZodSchema +``` + +**WASM bridge:** Python sends the JSON schema as a string through `structured-output-schema` on `agent-config` and `stream-args`. `entry.ts` reconstructs a Zod schema via `z.fromJSONSchema()`. + +**Python API change:** + +```python +# Standalone Python SDK (1.x) +agent = Agent(structured_output_model=MyModel) +result = agent("prompt", structured_output_model=MyModel) + +# WASM bridged Python SDK (2.x) +agent = Agent(structured_output=MyModel) +result = agent("prompt", structured_output=MyModel) +``` + +### 2. Tool name is fixed + +**TS design:** `StructuredOutputTool` uses a fixed name defined in `strands-ts/src/tools/structured-output-tool.ts:9`: + +```typescript +export const STRUCTURED_OUTPUT_TOOL_NAME = 'strands_structured_output' +``` + +**Python API change:** + +```python +# Standalone Python SDK (1.x) — tool name was the Pydantic model class name +# Conversation history shows: toolUse name="MyModel" + +# WASM bridged Python SDK (2.x) — fixed tool name +# Conversation history shows: toolUse name="strands_structured_output" +``` + +This does not affect user code. The tool name appears in conversation history but users do not reference it directly. + +### 3. Force-retry uses toolChoice, not a user message + +**TS design:** When the model stops without calling the structured output tool, the TS agent loop sets `toolChoice: { tool: { name: 'strands_structured_output' } }` and continues the loop (`strands-ts/src/agent/agent.ts:771`). No user message is appended. + +**Python API change:** + +```python +# Standalone Python SDK (1.x) — customizable force prompt +agent = Agent() +result = agent("prompt", structured_output_model=MyModel, structured_output_prompt="Format as MyModel now.") + +# WASM bridged Python SDK (2.x) — structured_output_prompt not available +agent = Agent() +result = agent("prompt", structured_output=MyModel) +# Force-retry handled automatically by TS SDK via toolChoice +``` + +The TS mechanism is more reliable because it forces the model to call the specific tool rather than relying on a text prompt. + +### 4. Validation runs in Zod, instantiation in Pydantic + +**TS design:** `StructuredOutputTool.stream()` validates via `this._schema.parse(toolUse.input)` (Zod) at `strands-ts/src/tools/structured-output-tool.ts:59`. On success, returns a `JsonBlock` with the validated data. On error, returns the error message for LLM retry. + +**WASM bridge:** The validated JSON returns through `stop-data.structured-output` as a JSON string. Python instantiates the Pydantic model from it: `so_model(**json.loads(json_str))`. + +**Python API change:** + +```python +# Both versions — same user experience +result = agent("Describe a person", structured_output=Person) +result.structured_output # Person(name="Alice", age=30) — Pydantic instance +``` + +Validation errors from Zod trigger model retry inside the TS guest (the model sees the error message and corrects its output). Custom Pydantic `@field_validator` logic that goes beyond JSON schema constraints cannot be enforced by Zod. Zod validates schema structure; Pydantic adds business logic on instantiation. + +### 5. Deprecated method removed + +**TS design:** No equivalent to `agent.structured_output(model, prompt)`. Structured output is configured via the constructor or per-invocation options. + +**Python API change:** + +```python +# Standalone Python SDK (1.x) — deprecated method +result = agent.structured_output(MyModel, "prompt") + +# WASM bridged Python SDK (2.x) — use standard invocation +result = agent("prompt", structured_output=MyModel) +# Or via helper method: +result = agent.structured_output(MyModel, "prompt") # still available, delegates to above +``` + +### 6. Error on force failure + +**TS design:** Throws `StructuredOutputError` (`strands-ts/src/errors.ts:203`) with message `"The model failed to invoke the structured output tool even after it was forced."`. + +**Python API change:** + +```python +# Standalone Python SDK (1.x) +from strands.types.exceptions import StructuredOutputException + +# WASM bridged Python SDK (2.x) +from strands.types.exceptions import StructuredOutputError +``` + +The exception is raised when the TS SDK's error message is detected in the stream. diff --git a/strands-wasm/entry.ts b/strands-wasm/entry.ts new file mode 100644 index 0000000000..0cc9ea882b --- /dev/null +++ b/strands-wasm/entry.ts @@ -0,0 +1,706 @@ +/** + * WASM component exporting strands:agent/api. + * + * The Agent resource holds a TS SDK Agent instance across multiple + * generate() calls. Each generate() returns a response-stream whose + * events() method yields the typed WIT stream-event. Consumers drain + * the ReadableStream to completion; componentize-js turns that into + * the component-model `stream` on the wire. + */ + +/// +/// +/// +/// +/// +/// +/// +/// + +import type { AgentConfig, InvokeArgs, RespondArgs, AgentError } from 'strands:agent/api@0.1.0' +import type { Message as WitMessage, PromptInput } from 'strands:agent/messages@0.1.0' +import type { + StreamEvent as WitStreamEvent, + StopEvent as WitStopEvent, + StopReason as WitStopReason, + AgentTrace as WitAgentTrace, + AgentMetrics as WitAgentMetrics, +} from 'strands:agent/streaming@0.1.0' +import type { ModelConfig as WitModelConfig, ModelParams as WitModelParams } from 'strands:agent/models@0.1.0' +import type { ToolSpec, ToolChoice as WitToolChoice } from 'strands:agent/tools@0.1.0' + +import { callTool } from 'strands:agent/tool-provider@0.1.0' +import { Agent, FunctionTool, SessionManager, FileStorage } from '@strands-agents/sdk' +import { S3Storage } from '@strands-agents/sdk/session/s3-storage' +import { AnthropicModel } from '@strands-agents/sdk/models/anthropic' +import { BedrockModel } from '@strands-agents/sdk/models/bedrock' +import { OpenAIModel } from '@strands-agents/sdk/models/openai' +import { GoogleModel } from '@strands-agents/sdk/models/google' +import type { + StopReason, + AgentStreamEvent, + Model, + BaseModelConfig, + Usage, + Metrics, + AgentResult, + ToolContext, + SystemPrompt, + InvokeArgs as SdkInvokeArgs, + Message, + StreamOptions, + ToolChoice, + ModelStreamEvent, + ContentBlock, + SaveLatestStrategy, + JSONValue, +} from '@strands-agents/sdk' +import { + ConversationManager, + NullConversationManager, + SlidingWindowConversationManager, + SummarizingConversationManager, +} from '@strands-agents/sdk' +import { z } from 'zod' + +// +// --- logging + error helpers -------------------------------------------- +// + +function errorMessage(err: unknown): string { + return err instanceof Error ? err.message : String(err) +} + +/** Wrap a throwable promise as a typed `agent-error` result. */ +async function asAgentResult(fn: () => Promise, storageErrorWrap = false): Promise<{ tag: 'ok'; val: T } | { tag: 'err'; val: AgentError }> { + try { + return { tag: 'ok', val: await fn() } + } catch (err) { + const msg = errorMessage(err) + if (storageErrorWrap) { + return { tag: 'err', val: { tag: 'storage', val: { tag: 'permanent', val: msg } } } + } + return { tag: 'err', val: { tag: 'internal', val: msg } } + } +} + +// +// --- small shape maps --------------------------------------------------- +// + +const STOP_REASON_MAP: Record = { + endTurn: 'end-turn', + toolUse: 'tool-use', + maxTokens: 'max-tokens', + contentFiltered: 'content-filtered', + guardrailIntervened: 'guardrail-intervened', + stopSequence: 'stop-sequence', + modelContextWindowExceeded: 'model-context-window-exceeded', + cancelled: 'cancelled', +} as unknown as Record + +function mapStopReason(reason: StopReason): WitStopReason { + return STOP_REASON_MAP[reason] ?? 'error' +} + +function mapUsage(src: Partial | null | undefined): WitStopEvent['usage'] { + if (src == null) return undefined + return { + inputTokens: src.inputTokens ?? 0, + outputTokens: src.outputTokens ?? 0, + totalTokens: src.totalTokens ?? (src.inputTokens ?? 0) + (src.outputTokens ?? 0), + cacheReadInputTokens: src.cacheReadInputTokens, + cacheWriteInputTokens: src.cacheWriteInputTokens, + } +} + +function mapMetrics(src: Partial | null | undefined): WitStopEvent['metrics'] { + if (src == null) return undefined + return { latencyMs: typeof src.latencyMs === 'number' ? src.latencyMs : 0 } +} + +/** Serialize a TS SDK Message to the WIT shape. */ +function mapMessage(message: Message): WitMessage { + return { + role: message.role, + content: message.content.map(mapContentBlock), + metadata: message.metadata + ? (JSON.parse(JSON.stringify(message.metadata)) as WitMessage['metadata']) + : undefined, + } as WitMessage +} + +/** Serialize a TS SDK ContentBlock to the WIT tagged-variant shape. */ +function mapContentBlock(block: ContentBlock): import('strands:agent/messages@0.1.0').ContentBlock { + type WitBlock = import('strands:agent/messages@0.1.0').ContentBlock + // block.type is the SDK class discriminator; toJSON drops class identity but keeps fields. + const payload = JSON.parse(JSON.stringify(block)) + switch (block.type) { + case 'textBlock': return { tag: 'text', val: payload } as WitBlock + case 'toolUseBlock': return { tag: 'tool-use', val: payload } as WitBlock + case 'toolResultBlock': return { tag: 'tool-result', val: payload } as WitBlock + case 'reasoningBlock': return { tag: 'reasoning', val: payload } as WitBlock + case 'cachePointBlock': return { tag: 'cache-point', val: payload } as WitBlock + case 'imageBlock': return { tag: 'image', val: payload } as WitBlock + case 'videoBlock': return { tag: 'video', val: payload } as WitBlock + case 'documentBlock': return { tag: 'document', val: payload } as WitBlock + case 'citationsBlock': return { tag: 'citations', val: payload } as WitBlock + case 'guardContentBlock': return { tag: 'guard-content', val: payload } as WitBlock + default: { + block satisfies never + throw new Error(`unknown content block: ${(block as { type: string }).type}`) + } + } +} + +// +// --- stream event mapping ------------------------------------------------ +// + +/** + * Translate a TS SDK `AgentStreamEvent` to its WIT counterpart. Returns + * `null` for events whose data is available through another arm (e.g. + * the terminal `AgentResultEvent`, which is surfaced via `stop`). See + * docs/BRIDGE-COLLAPSE-PLAN.md for the plan to delete this function. + */ +function mapEvent(event: AgentStreamEvent): WitStreamEvent | null { + switch (event.type) { + case 'beforeInvocationEvent': + return { tag: 'before-invocation', val: { invocationState: '{}' } } + case 'afterInvocationEvent': + return { tag: 'after-invocation', val: { invocationState: '{}' } } + case 'messageAddedEvent': + return { tag: 'message-added', val: { message: mapMessage(event.message) } } + case 'beforeModelCallEvent': + return { tag: 'before-model-call', val: { projectedInputTokens: undefined } } + case 'afterModelCallEvent': + return { + tag: 'after-model-call', + val: { + attemptCount: 1, + stopData: event.stopData + ? { + message: mapMessage(event.stopData.message), + stopReason: mapStopReason(event.stopData.stopReason), + redaction: event.stopData.redaction ? { userMessage: event.stopData.redaction.userMessage } : undefined, + } + : undefined, + error: event.error ? { tag: 'internal', val: event.error.message } : undefined, + }, + } + case 'beforeToolsEvent': + return { tag: 'before-tools', val: { message: mapMessage(event.message) } } + case 'afterToolsEvent': + return { tag: 'after-tools', val: { message: mapMessage(event.message) } } + case 'beforeToolCallEvent': + return { + tag: 'before-tool-call', + val: { + toolUse: { + name: event.toolUse.name, + toolUseId: event.toolUse.toolUseId, + input: JSON.stringify(event.toolUse.input ?? {}), + }, + }, + } + case 'afterToolCallEvent': + return { + tag: 'after-tool-call', + val: { + toolUse: { + name: event.toolUse.name, + toolUseId: event.toolUse.toolUseId, + input: JSON.stringify(event.toolUse.input ?? {}), + }, + toolResult: mapContentBlock(event.result) as unknown as import('strands:agent/messages@0.1.0').ToolResultBlock, + error: event.error ? { tag: 'execution-failed', val: event.error.message } : undefined, + }, + } + case 'contentBlockEvent': + return { tag: 'content-block', val: { contentBlock: mapContentBlock(event.contentBlock) } } + case 'modelMessageEvent': + return { + tag: 'model-message', + val: { message: mapMessage(event.message), stopReason: mapStopReason(event.stopReason) }, + } + case 'toolResultEvent': + return { + tag: 'tool-result-hook', + val: { toolResult: mapContentBlock(event.result) as unknown as import('strands:agent/messages@0.1.0').ToolResultBlock }, + } + case 'toolStreamUpdateEvent': + return { tag: 'tool-update', val: { data: JSON.stringify(event.event.data ?? null) } } + case 'modelStreamUpdateEvent': + return { tag: 'model-update', val: { event: JSON.stringify(event.event) } } + case 'agentResultEvent': + // The terminal `stop` arm carries this data instead. + return null + case 'interruptEvent': + return { + tag: 'interrupt', + val: { + id: event.interrupt.id, + name: event.interrupt.name, + reason: + event.interrupt.reason !== undefined + ? typeof event.interrupt.reason === 'string' + ? event.interrupt.reason + : JSON.stringify(event.interrupt.reason) + : undefined, + }, + } + default: { + event satisfies never + return null + } + } +} + +function mapStopEvent(result: AgentResult): WitStreamEvent { + return { + tag: 'stop', + val: { + reason: mapStopReason(result.stopReason), + usage: mapUsage(result.metrics?.accumulatedUsage), + metrics: mapMetrics(result.metrics?.accumulatedMetrics), + structuredOutput: result.structuredOutput !== undefined ? JSON.stringify(result.structuredOutput) : undefined, + }, + } +} + +// +// --- config builders ----------------------------------------------------- +// + +function modelParamsConfig(params?: WitModelParams): Record { + if (!params) return {} + return { + ...(params.maxTokens != null ? { maxTokens: params.maxTokens } : {}), + ...(params.temperature != null ? { temperature: params.temperature } : {}), + ...(params.topP != null ? { topP: params.topP } : {}), + } +} + +function createModel(config?: WitModelConfig, params?: WitModelParams): Model { + const base = modelParamsConfig(params) + if (!config) return new BedrockModel(base) + + switch (config.tag) { + case 'anthropic': { + const extra = config.val.additionalConfig ? JSON.parse(config.val.additionalConfig) : {} + return new AnthropicModel({ + ...base, + ...(config.val.modelId ? { modelId: config.val.modelId } : {}), + ...(config.val.apiKey ? { apiKey: config.val.apiKey } : {}), + ...extra, + }) + } + case 'bedrock': { + const extra = config.val.additionalConfig ? JSON.parse(config.val.additionalConfig) : {} + const clientConfig: Record = extra.clientConfig ?? {} + if (config.val.accessKeyId && config.val.secretAccessKey) { + clientConfig.credentials = { + accessKeyId: config.val.accessKeyId, + secretAccessKey: config.val.secretAccessKey, + ...(config.val.sessionToken ? { sessionToken: config.val.sessionToken } : {}), + } + } + return new BedrockModel({ + ...base, + ...(config.val.modelId ? { modelId: config.val.modelId } : {}), + ...(config.val.region ? { region: config.val.region } : {}), + clientConfig, + ...extra, + }) + } + case 'openai': { + const extra = config.val.additionalConfig ? JSON.parse(config.val.additionalConfig) : {} + return new OpenAIModel({ + ...base, + ...(config.val.modelId ? { modelId: config.val.modelId } : {}), + ...(config.val.apiKey ? { apiKey: config.val.apiKey } : {}), + ...extra, + }) + } + case 'gemini': { + const extra = config.val.additionalConfig ? JSON.parse(config.val.additionalConfig) : {} + return new GoogleModel({ + ...base, + ...(config.val.modelId ? { modelId: config.val.modelId } : {}), + ...(config.val.apiKey ? { apiKey: config.val.apiKey } : {}), + ...extra, + }) + } + case 'custom': + // Phase 2: wire `model-provider` host interface. + throw new Error(`model-config.custom is not implemented yet (provider-id: ${config.val.providerId})`) + default: { + config satisfies never + throw new Error(`Unknown model-config arm`) + } + } +} + +/** Convert WIT ToolSpecs into TS FunctionTools that call back to the host. */ +function createTools(specs: ToolSpec[] | undefined): FunctionTool[] | undefined { + if (!specs || specs.length === 0) return undefined + + return specs.map( + (spec) => + new FunctionTool({ + name: spec.name, + description: spec.description, + inputSchema: JSON.parse(spec.inputSchema), + callback: async (input: unknown, toolContext: ToolContext) => { + const stream = callTool({ + name: spec.name, + input: JSON.stringify(input), + toolUseId: toolContext.toolUse.toolUseId, + }) + for (;;) { + const value = stream.read() + if (value === undefined) { + throw new Error(`tool ${spec.name} stream ended without complete/error`) + } + switch (value.tag) { + case 'data': + // Streaming tool progress is not surfaced to the SDK caller today. + continue + case 'complete': + return value.val as unknown as JSONValue + case 'error': + throw new Error(`tool ${spec.name} failed: ${value.val.tag}`) + } + } + }, + }) + ) +} + +function buildSystemPrompt(config: AgentConfig): SystemPrompt | undefined { + const sp = config.systemPrompt + if (!sp) return undefined + if (sp.tag === 'text') return sp.val + return sp.val as unknown as SystemPrompt +} + +function createToolChoiceProxy(baseModel: Model, toolChoice: ToolChoice): Model { + return new Proxy(baseModel, { + get(target, prop, receiver) { + if (prop === 'stream') { + return async function* (messages: Message[], options?: StreamOptions): AsyncIterable { + yield* target.stream(messages, { ...options, toolChoice }) + } + } + return Reflect.get(target, prop, receiver) + }, + }) as Model +} + +/** Project a WIT `tool-choice` variant onto the TS SDK shape. */ +function toolChoiceFromWit(tc: WitToolChoice): ToolChoice { + switch (tc.tag) { + case 'auto': + return { auto: {} } + case 'any': + return { any: {} } + case 'named': + return { tool: { name: tc.val } } + } +} + +function createSessionManager(config: AgentConfig): SessionManager | undefined { + if (!config.session) return undefined + const sc = config.session + let storage + switch (sc.storage.tag) { + case 'file': + storage = new FileStorage(sc.storage.val.baseDir) + break + case 's3': { + const s3 = sc.storage.val + storage = new S3Storage({ + bucket: s3.bucket, + ...(s3.region ? { region: s3.region } : {}), + ...(s3.prefix ? { prefix: s3.prefix } : {}), + }) + break + } + case 'custom': + // Phase 2: wire `snapshot-storage` host interface. + throw new Error(`storage-config.custom is not implemented yet (backend-id: ${sc.storage.val.backendId})`) + } + + const saveLatestOn: SaveLatestStrategy | undefined = sc.saveLatest + ? sc.saveLatest.tag === 'trigger' + ? 'trigger' + : sc.saveLatest.tag + : undefined + return new SessionManager({ + sessionId: sc.sessionId, + storage: { snapshot: storage }, + ...(saveLatestOn !== undefined ? { saveLatestOn } : {}), + }) +} + +function createConversationManager(config: AgentConfig): ConversationManager | undefined { + const cm = config.conversationManager + if (!cm) return undefined + switch (cm.tag) { + case 'none': + return new NullConversationManager() + case 'sliding-window': + return new SlidingWindowConversationManager({ + windowSize: cm.val.windowSize, + shouldTruncateResults: cm.val.shouldTruncateResults, + }) + case 'summarizing': { + const summaryModel = cm.val.summarizationModel ? createModel(cm.val.summarizationModel) : undefined + return new SummarizingConversationManager({ + model: summaryModel, + summaryRatio: cm.val.summaryRatio, + preserveRecentMessages: cm.val.preserveRecentMessages, + summarizationSystemPrompt: cm.val.summarizationSystemPrompt, + }) + } + } +} + +function parseStructuredOutputSchema(jsonStr: string | undefined): z.ZodSchema | undefined { + if (!jsonStr) return undefined + try { + return z.fromJSONSchema(JSON.parse(jsonStr)) as z.ZodSchema + } catch (e) { + throw new Error(`Invalid structured output schema: ${errorMessage(e)}`) + } +} + +function invokeInputFromWit(input: PromptInput): SdkInvokeArgs { + return input.tag === 'text' ? input.val : (input.val as unknown as SdkInvokeArgs) +} + +// +// --- resources ----------------------------------------------------------- +// + +class AgentImpl { + private agent: Agent + private defaultTools: FunctionTool[] | undefined + private sessionManager: SessionManager | undefined + + constructor(config: AgentConfig) { + const model = createModel(config.model, config.modelParams) + this.defaultTools = createTools(config.tools) + this.sessionManager = createSessionManager(config) + + this.agent = new Agent({ + model, + systemPrompt: buildSystemPrompt(config), + tools: this.defaultTools, + sessionManager: this.sessionManager, + conversationManager: createConversationManager(config), + structuredOutputSchema: parseStructuredOutputSchema(config.structuredOutputSchema), + printer: config.displayOutput ?? true, + }) + } + + generate(args: InvokeArgs): ResponseStreamImpl { + if (args.tools) { + const requestTools = createTools(args.tools) + this.agent.toolRegistry.clear() + if (requestTools) this.agent.toolRegistry.add(requestTools) + } + + let originalModel: Model | undefined + if (args.toolChoice) { + originalModel = this.agent.model + this.agent.model = createToolChoiceProxy(originalModel, toolChoiceFromWit(args.toolChoice)) + } + + const structuredOutputSchema = parseStructuredOutputSchema(args.structuredOutputSchema) + return new ResponseStreamImpl( + this.agent, + args.input, + this.defaultTools, + originalModel, + structuredOutputSchema + ) + } + + getMessages(): WitMessage[] { + return this.agent.messages.map(mapMessage) + } + + setMessages(messages: WitMessage[]): { tag: 'ok'; val: void } | { tag: 'err'; val: AgentError } { + try { + const parsed = messages.map((m) => JSON.parse(JSON.stringify(m)) as Message) + this.agent.messages.splice(0, this.agent.messages.length, ...parsed) + return { tag: 'ok', val: undefined } + } catch (err) { + return { tag: 'err', val: { tag: 'invalid-input', val: errorMessage(err) } } + } + } + + getAppState(): string { + return JSON.stringify(this.agent.appState.getAll()) + } + + setAppState(json: string): { tag: 'ok'; val: void } | { tag: 'err'; val: AgentError } { + try { + const parsed = JSON.parse(json) as Record + this.agent.appState.clear() + for (const [k, v] of Object.entries(parsed)) this.agent.appState.set(k, v) + return { tag: 'ok', val: undefined } + } catch (err) { + return { tag: 'err', val: { tag: 'invalid-input', val: errorMessage(err) } } + } + } + + getModelState(): string { + return JSON.stringify(this.agent.modelState.getAll()) + } + + setModelState(json: string): { tag: 'ok'; val: void } | { tag: 'err'; val: AgentError } { + try { + const parsed = JSON.parse(json) as Record + this.agent.modelState.clear() + for (const [k, v] of Object.entries(parsed)) this.agent.modelState.set(k, v) + return { tag: 'ok', val: undefined } + } catch (err) { + return { tag: 'err', val: { tag: 'invalid-input', val: errorMessage(err) } } + } + } + + getTraces(): WitAgentTrace[] { + // Phase 2: surface the SDK's traces here. For now return empty. + return [] + } + + getMetrics(): WitAgentMetrics { + // Phase 2: surface the SDK's metrics here. For now return zeroes. + return { + cycleCount: 0, + accumulatedUsage: { inputTokens: 0, outputTokens: 0, totalTokens: 0, cacheReadInputTokens: undefined, cacheWriteInputTokens: undefined }, + accumulatedMetrics: { latencyMs: 0 }, + invocations: [], + cycles: [], + toolMetrics: [], + latestContextSize: undefined, + projectedContextSize: undefined, + } + } + + async saveSession(): Promise<{ tag: 'ok'; val: void } | { tag: 'err'; val: AgentError }> { + if (!this.sessionManager) return { tag: 'err', val: { tag: 'no-session-configured' } } + return asAgentResult(async () => { + await this.sessionManager!.saveSnapshot({ target: this.agent, isLatest: true }) + }, true) + } + + async listSnapshots(): Promise<{ tag: 'ok'; val: string[] } | { tag: 'err'; val: AgentError }> { + if (!this.sessionManager) return { tag: 'err', val: { tag: 'no-session-configured' } } + return asAgentResult(() => this.sessionManager!.listSnapshotIds({ target: this.agent }), true) + } + + async deleteSession(): Promise<{ tag: 'ok'; val: void } | { tag: 'err'; val: AgentError }> { + if (!this.sessionManager) return { tag: 'err', val: { tag: 'no-session-configured' } } + return { tag: 'err', val: { tag: 'internal', val: 'deleteSession not yet implemented' } } + } +} + +class EventStreamImpl { + private parent: ResponseStreamImpl + + constructor(parent: ResponseStreamImpl) { + this.parent = parent + } + + read(): Promise { + return this.parent._pullNext() + } +} + +class ResponseStreamImpl { + private done = false + private generator: AsyncGenerator + private interruptResolve: ((payload: string) => void) | null = null + private agent: Agent + private defaultTools: FunctionTool[] | undefined + private originalModel: Model | undefined + private pendingStop: WitStreamEvent | undefined + + constructor( + agent: Agent, + input: PromptInput, + defaultTools?: FunctionTool[], + originalModel?: Model, + structuredOutputSchema?: z.ZodSchema + ) { + this.agent = agent + this.defaultTools = defaultTools + this.originalModel = originalModel + this.generator = agent.stream(invokeInputFromWit(input), { structuredOutputSchema }) + } + + private restoreDefaults(): void { + if (this.originalModel) this.agent.model = this.originalModel + this.agent.toolRegistry.clear() + if (this.defaultTools) this.agent.toolRegistry.add(this.defaultTools) + } + + /** @internal Drains both the SDK iterator and any pending terminal stop. */ + async _pullNext(): Promise { + if (this.pendingStop) { + const stop = this.pendingStop + this.pendingStop = undefined + return stop + } + if (this.done) return undefined + while (true) { + try { + const result = await this.generator.next() + if (result.done) { + this.done = true + this.restoreDefaults() + return result.value ? mapStopEvent(result.value) : undefined + } + const mapped = mapEvent(result.value) + if (mapped) return mapped + // null means the SDK event has no on-stream representation; loop. + } catch (err) { + this.done = true + this.restoreDefaults() + return { tag: 'error', val: { tag: 'internal', val: errorMessage(err) } } + } + } + } + + events(): EventStreamImpl { + return new EventStreamImpl(this) + } + + respond(args: RespondArgs): { tag: 'ok'; val: void } | { tag: 'err'; val: AgentError } { + if (!this.interruptResolve) { + return { tag: 'err', val: { tag: 'unknown-interrupt', val: args.interruptId } } + } + // Phase 2: look up the interrupt by id and resolve the matching promise. + this.interruptResolve(args.response) + this.interruptResolve = null + return { tag: 'ok', val: undefined } + } + + cancel(): void { + this.done = true + this.restoreDefaults() + void this.generator.return(undefined) + } +} + +export const api = { + Agent: AgentImpl, + ResponseStream: ResponseStreamImpl, + EventStream: EventStreamImpl, +} + +// Exported for contract testing. Not used by the WASM component build. +export { mapEvent, mapStopEvent, mapStopReason, mapUsage, mapMetrics, mapMessage, mapContentBlock, createTools, createToolChoiceProxy, toolChoiceFromWit } diff --git a/strands-wasm/package.json b/strands-wasm/package.json new file mode 100644 index 0000000000..b7e59147c4 --- /dev/null +++ b/strands-wasm/package.json @@ -0,0 +1,30 @@ +{ + "name": "@strands-agents/wasm", + "version": "0.0.1-development", + "private": true, + "description": "WASM component build for Strands Agents SDK", + "type": "module", + "scripts": { + "generate": "jco guest-types ../wit --name strands:agent --world-name agent --out-dir generated", + "build": "node build.js", + "test": "vitest run --project unit", + "test:guest": "vitest run --project guest", + "test:guest:integ": "STRANDS_INTEG=true vitest run --project guest", + "transpile": "jco transpile dist/strands-agent.wasm -o dist/transpiled --instantiation async", + "type-check": "npm run generate && tsc", + "clean": "rm -rf dist node_modules package-lock.json" + }, + "dependencies": { + "@aws/bedrock-token-generator": "https://github.com/pgrayy/wasm-deps/releases/download/token-gen-v1.1.0/aws-bedrock-token-generator-1.1.0.tgz", + "@strands-agents/sdk": "*", + "zod": "^4.1.12" + }, + "devDependencies": { + "@bytecodealliance/jco": "^1.16.1", + "@bytecodealliance/preview2-shim": "^0.17.9", + "@chaynabors/componentize-js": "^0.19.3", + "esbuild": "^0.27.4", + "typescript": "^6.0.2", + "vitest": "^3.2.1" + } +} diff --git a/strands-wasm/test/guest/boundary.test.ts b/strands-wasm/test/guest/boundary.test.ts new file mode 100644 index 0000000000..104db02f12 --- /dev/null +++ b/strands-wasm/test/guest/boundary.test.ts @@ -0,0 +1,71 @@ +import { describe, it, expect, beforeAll, beforeEach } from 'vitest' +import { createGuest, drainStream, LogEntry } from './harness' + +describe('Level 2a: boundary smoke tests', () => { + const anthropicModel = { tag: 'anthropic' as const, val: { apiKey: 'sk-fake-key-for-testing' } } + let root: any + const logEntries: LogEntry[] = [] + + function createAgent(): any { + return new root.api.Agent({ model: anthropicModel }) + } + + beforeAll(async () => { + root = await createGuest({ + log: (entry) => logEntries.push(entry), + callTool: () => JSON.stringify({ status: 'success', content: [{ text: 'mock result' }] }), + }) + }, 120_000) + + beforeEach(() => { + logEntries.length = 0 + }) + + it('component loads and instantiate succeeds', () => { + expect(root).toBeDefined() + expect(root.api).toBeDefined() + expect(root.api.Agent).toBeDefined() + }) + + it('Agent construction succeeds', () => { + expect(createAgent()).toBeDefined() + }) + + it('getMessages returns empty array on fresh agent', () => { + expect(createAgent().getMessages()).toBe('[]') + }) + + it('setMessages → getMessages round-trips correctly', () => { + const agent = createAgent() + const messages = JSON.stringify([{ role: 'user', content: [{ type: 'text', text: 'hello' }] }]) + agent.setMessages({ json: messages }) + expect(agent.getMessages()).toBe(messages) + }) + + it('host-log mock receives log entries during construction', () => { + createAgent() + expect(logEntries.length).toBeGreaterThan(0) + expect(logEntries[0]).toMatchObject({ + level: expect.stringMatching(/^(trace|debug|info|warn|error)$/), + message: expect.any(String), + }) + }) + + it('generate with fake API key returns error event', async () => { + const agent = new root.api.Agent({ + model: { + ...anthropicModel, + val: { ...anthropicModel.val, additionalConfig: JSON.stringify({ timeout: 10_000 }) }, + }, + }) + const stream = agent.generate({ input: 'hello', tools: undefined, toolChoice: undefined }) + const events = await drainStream(stream) + const errorEvent = events.find((e: any) => e.tag === 'error') + expect(errorEvent).toBeDefined() + expect(typeof errorEvent.val).toBe('string') + }) + + it('deleteSession throws not-yet-implemented error', () => { + expect(() => createAgent().deleteSession()).toThrow() + }) +}) diff --git a/strands-wasm/test/guest/harness.ts b/strands-wasm/test/guest/harness.ts new file mode 100644 index 0000000000..b875a3d55a --- /dev/null +++ b/strands-wasm/test/guest/harness.ts @@ -0,0 +1,68 @@ +import { readFile } from 'node:fs/promises' +import { join } from 'node:path' +import { WASIShim } from '@bytecodealliance/preview2-shim/instantiation' + +const transpileDir = join(__dirname, '..', '..', 'dist', 'transpiled') + +/** Log entry forwarded from the WASM guest to the host. */ +export interface LogEntry { + level: string + message: string + context?: string +} + +/** Arguments passed from the WASM guest to the host tool-provider import. */ +export interface CallToolArgs { + name: string + input: string + toolUseId: string +} + +/** + * WIT Result type for batch tool calls (list\\>). + * jco does NOT unwrap list elements — the host must return the tagged variant. + */ +export type ToolResult = { tag: 'ok'; val: string } | { tag: 'err'; val: string } + +/** + * Host-side mock implementations injected into the WASM guest. + * + * callTool returns a plain string (success) or throws (error) — jco wraps the + * raw return into \{tag:'ok', val\} itself for WIT result\. + * callTools returns ToolResult[] directly — jco does NOT unwrap list elements. + */ +export interface HostMocks { + log: (entry: LogEntry) => void + callTool: (args: CallToolArgs) => string + callTools?: (args: { calls: CallToolArgs[] }) => ToolResult[] +} + +/** Compile and instantiate the WASM guest component with the given host mocks. */ +export async function createGuest(mocks: HostMocks): Promise { + const getCoreModule = async (path: string): Promise => { + const bytes = await readFile(join(transpileDir, path)) + return WebAssembly.compile(bytes) + } + + const { instantiate } = await import('../../dist/transpiled/strands-agent.js') + + return instantiate(getCoreModule, { + 'strands:agent/host-log': { log: mocks.log }, + 'strands:agent/tool-provider': { + callTool: mocks.callTool, + callTools: mocks.callTools ?? ((args: { calls: CallToolArgs[] }) => args.calls.map(mocks.callTool)), + }, + 'strands:agent/types': {}, + ...new WASIShim().getImportObject(), + }) +} + +/** Drain all batches from a guest ResponseStream into a flat event array. */ +export async function drainStream(stream: any): Promise { + const events: any[] = [] + let batch + while ((batch = await stream.readNext()) !== undefined) { + events.push(...batch) + } + return events +} diff --git a/strands-wasm/test/guest/roundtrip.test.ts b/strands-wasm/test/guest/roundtrip.test.ts new file mode 100644 index 0000000000..9b3cecb831 --- /dev/null +++ b/strands-wasm/test/guest/roundtrip.test.ts @@ -0,0 +1,144 @@ +import { describe, it, expect, vi, beforeAll, beforeEach } from 'vitest' +import { createGuest, drainStream, LogEntry, CallToolArgs } from './harness' + +interface ToolSpec { + name: string + description: string + inputSchema: string +} + +const bedrockConfig = { + model: { tag: 'bedrock' as const, val: { modelId: 'anthropic.claude-3-haiku-20240307-v1:0' } }, + modelParams: { maxTokens: 256 }, +} + +function generate(agent: any, input: string): any { + return agent.generate({ input, tools: undefined, toolChoice: undefined }) +} + +describe.runIf(process.env.STRANDS_INTEG === 'true')('Level 2b: full round-trip tests', () => { + let root: any + const logEntries: LogEntry[] = [] + const callToolMock = vi.fn((args: CallToolArgs) => { + return JSON.stringify({ status: 'success', content: [{ text: `mock result for ${args.name}` }] }) + }) + + beforeAll(async () => { + root = await createGuest({ + log: (entry) => logEntries.push(entry), + callTool: callToolMock, + }) + }, 120_000) + + beforeEach(() => { + logEntries.length = 0 + callToolMock.mockClear() + }) + + it('full generate produces text-delta and stop events', async () => { + const agent = new root.api.Agent({ + ...bedrockConfig, + systemPrompt: 'Respond with exactly one word: hello', + }) + const stream = generate(agent, 'Say hello') + const events = await drainStream(stream) + const textDeltas = events.filter((e: any) => e.tag === 'text-delta') + expect(textDeltas.length).toBeGreaterThan(0) + for (const td of textDeltas) { + expect(typeof td.val).toBe('string') + } + const stopEvent = events.find((e: any) => e.tag === 'stop') + expect(stopEvent).toBeDefined() + expect(stopEvent.val).toMatchObject({ + reason: 'end-turn', + }) + }) + + it('tool call flow — model calls tool, host mock receives it', async () => { + const weatherTool: ToolSpec = { + name: 'get_weather', + description: 'Get the current weather for a location', + inputSchema: JSON.stringify({ + type: 'object', + properties: { location: { type: 'string', description: 'City name' } }, + required: ['location'], + }), + } + const agent = new root.api.Agent({ + ...bedrockConfig, + systemPrompt: 'You have a get_weather tool. Use it to answer weather questions. Do not ask for clarification.', + tools: [weatherTool], + }) + const stream = generate(agent, 'What is the weather in Seattle?') + const events = await drainStream(stream) + expect(callToolMock).toHaveBeenCalled() + expect(callToolMock.mock.calls[0][0].name).toBe('get_weather') + const toolUseEvent = events.find((e: any) => e.tag === 'tool-use') + expect(toolUseEvent).toBeDefined() + expect(toolUseEvent.val.name).toBe('get_weather') + expect(typeof toolUseEvent.val.toolUseId).toBe('string') + expect(toolUseEvent.val.toolUseId.length).toBeGreaterThan(0) + expect(() => JSON.parse(toolUseEvent.val.input)).not.toThrow() + const toolResultEvent = events.find((e: any) => e.tag === 'tool-result') + expect(toolResultEvent).toBeDefined() + expect(toolResultEvent.val.status).toBe('success') + expect(typeof toolResultEvent.val.content).toBe('string') + }) + + it('lifecycle events appear in readNext batches', async () => { + const agent = new root.api.Agent({ ...bedrockConfig, systemPrompt: 'Say hi' }) + const stream = generate(agent, 'hello') + const events = await drainStream(stream) + const lifecycleEvents = events.filter((e: any) => e.tag === 'lifecycle') + expect(lifecycleEvents.length).toBeGreaterThan(0) + const beforeModelCall = lifecycleEvents.find((e: any) => e.val.eventType === 'before-model-call') + expect(beforeModelCall).toBeDefined() + expect(beforeModelCall.val).toMatchObject({ + eventType: 'before-model-call', + toolUse: undefined, + toolResult: undefined, + }) + }) + + it('metadata event with usage tokens appears', async () => { + const agent = new root.api.Agent({ ...bedrockConfig, systemPrompt: 'Say one word' }) + const stream = generate(agent, 'go') + const events = await drainStream(stream) + const metadataEvent = events.find((e: any) => e.tag === 'metadata') + expect(metadataEvent).toBeDefined() + expect(metadataEvent.val.usage).toBeDefined() + expect(metadataEvent.val.usage.inputTokens).toBeGreaterThan(0) + expect(metadataEvent.val.usage.outputTokens).toBeGreaterThanOrEqual(0) + expect(metadataEvent.val.usage.totalTokens).toBeGreaterThan(0) + }) + + it('cancel terminates the stream', async () => { + const agent = new root.api.Agent({ + ...bedrockConfig, + systemPrompt: 'Write a very long story about a dragon', + }) + const stream = generate(agent, 'begin') + const firstBatch = await stream.readNext() + expect(firstBatch).toBeDefined() + stream.cancel() + const afterCancel = await stream.readNext() + expect(afterCancel).toBeUndefined() + }) + + it('multi-turn: setMessages then generate continues context', async () => { + const agent = new root.api.Agent({ + ...bedrockConfig, + systemPrompt: 'Remember what the user tells you', + }) + const priorMessages = [ + { role: 'user', content: [{ type: 'text', text: 'My name is Alice' }] }, + { role: 'assistant', content: [{ type: 'text', text: 'Nice to meet you, Alice!' }] }, + ] + agent.setMessages({ json: JSON.stringify(priorMessages) }) + const stream = generate(agent, 'What is my name?') + const events = await drainStream(stream) + const textDeltas = events.filter((e: any) => e.tag === 'text-delta') + const fullText = textDeltas.map((e: any) => e.val).join('') + expect(fullText.toLowerCase()).toContain('alice') + }) +}) diff --git a/strands-wasm/tsconfig.json b/strands-wasm/tsconfig.json new file mode 100644 index 0000000000..284761478c --- /dev/null +++ b/strands-wasm/tsconfig.json @@ -0,0 +1,15 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "NodeNext", + "moduleResolution": "nodenext", + "lib": ["ES2022", "DOM", "DOM.Iterable"], + "strict": true, + "noEmit": true, + "esModuleInterop": true, + "skipLibCheck": true, + "isolatedModules": true, + "verbatimModuleSyntax": true + }, + "include": ["entry.ts", "generated/**/*.d.ts"] +} diff --git a/strands-wasm/vitest.config.ts b/strands-wasm/vitest.config.ts new file mode 100644 index 0000000000..a22d5b46e9 --- /dev/null +++ b/strands-wasm/vitest.config.ts @@ -0,0 +1,30 @@ +import { defineConfig } from 'vitest/config' +import { resolve } from 'node:path' + +export default defineConfig({ + test: { + projects: [ + { + test: { + name: { label: 'unit' }, + include: ['__tests__/**/*.test.ts'], + }, + resolve: { + alias: { + 'strands:agent/tool-provider': resolve(__dirname, '__fixtures__/tool-provider.ts'), + 'strands:agent/host-log': resolve(__dirname, '__fixtures__/host-log.ts'), + '$/fixtures': resolve(__dirname, '../strands-ts/src/__fixtures__'), + }, + }, + }, + { + test: { + name: { label: 'guest' }, + include: ['test/guest/**/*.test.ts'], + testTimeout: 60_000, + pool: 'forks', + }, + }, + ], + }, +}) diff --git a/wit/agent.wit b/wit/agent.wit new file mode 100644 index 0000000000..d6d8b76ece --- /dev/null +++ b/wit/agent.wit @@ -0,0 +1,214 @@ +package strands:agent@0.1.0; + +/// Top-level agent API. Construct an `agent`, call `generate` to start an +/// invocation, and drain the returned `response-stream`. +interface api { + use messages.{message, content-block, prompt-input}; + use models.{model-config, model-params}; + use tools.{tool-spec, tool-choice, agent-as-tool-config}; + use sessions.{session-config, storage-error}; + use conversation.{conversation-manager-config}; + use retry.{retry-config}; + use streaming.{stream-event, agent-trace, agent-metrics}; + use vended.{vended-tool, vended-plugin}; + use mcp.{mcp-client-config}; + + /// Concurrent-execution options. + record concurrent-options { + /// Upper bound on tool calls running at once. Absent means no limit. + max-concurrency: option, + } + + /// Strategy for executing tool calls emitted in a single assistant turn. + variant tool-executor-strategy { + /// Run tool calls one at a time, in order. + sequential, + /// Run tool calls in parallel (default). + concurrent(concurrent-options), + } + + /// Scalar attribute value attached to a trace. + variant attribute-value { + /// String value. + string-value(string), + /// 64-bit signed integer. + int-value(s64), + /// 64-bit float. + double-value(f64), + /// Boolean. + bool-value(bool), + } + + /// Key-value pair attached to every OpenTelemetry span the agent emits. + /// Distinct from `streaming.trace-metadata-entry`, which is string-only. + record trace-attribute { + /// Attribute key. + key: string, + /// Attribute value. + value: attribute-value, + } + + /// W3C Trace Context headers linking the agent's spans to a caller's trace. + record trace-context { + /// `traceparent` header value. + traceparent: string, + /// `tracestate` header value. Absent when no vendor state is set. + tracestate: option, + } + + /// Display-level identity of the agent; all fields default to sensible values. + record agent-identity { + /// Display name. Defaults to `"Strands Agent"`. + name: option, + /// Stable identifier. Defaults to `"agent"`. + id: option, + /// Human-readable description of what the agent does. + description: option, + } + + /// Configuration passed to the `agent` constructor. + /// Invalid config surfaces on the first `generate` as `invalid-input`. + record agent-config { + /// Model provider. Defaults to Bedrock with a sensible model id when absent. + model: option, + /// Sampling parameters applied to every model call. + model-params: option, + /// Initial conversation history. + messages: option>, + /// System prompt. Either plain text or structured content blocks. + system-prompt: option, + /// Tools available to the model. Overridable per-invocation via `invoke-args.tools`. + tools: option>, + /// Child agents exposed as tools, registered alongside `tools`. + agent-tools: option>, + /// Built-in tools to enable. Added to `tools`. + vended-tools: option>, + /// Built-in plugins to enable. + vended-plugins: option>, + /// MCP clients whose tools should be exposed to the model. + mcp-clients: option>, + /// Display-level identity (name, id, description). + identity: option, + /// How tool calls from a single assistant turn are scheduled. + tool-executor: option, + /// Mirror agent output to the application's console. Defaults to `true`. + display-output: option, + /// Attributes added to every OpenTelemetry span. + trace-attributes: option>, + /// W3C Trace Context linking the agent's spans to a caller-supplied trace. + trace-context: option, + /// Session persistence. Absent means no persistence. + session: option, + /// Conversation history management. Defaults to a sliding window. + conversation-manager: option, + /// Retry policy for failed model calls. Defaults to exponential backoff capped at 6 attempts. + retry: option, + /// JSON Schema for structured output. Opaque on the wire. + structured-output-schema: option, + /// Initial app-state values as an opaque JSON object. + app-state: option, + /// Initial model-provider state as an opaque JSON object. Set when hydrating from a snapshot. + model-state: option, + } + + /// Arguments for `agent.generate`. + record invoke-args { + /// User input. + input: prompt-input, + /// Per-invocation tool override. Replaces the agent's registered tools. + tools: option>, + /// Tool choice policy. + tool-choice: option, + /// Per-invocation structured-output schema. Overrides the agent-level one. + structured-output-schema: option, + } + + /// Payload supplied when resuming from a human-in-the-loop interrupt. + record respond-args { + /// Id of the interrupt being responded to; matches the `interrupt` stream event. + interrupt-id: string, + /// User's response as a JSON value. Opaque on the wire. + response: string, + } + + /// Why an agent-resource call failed. + variant agent-error { + /// The agent was constructed without a session config. + no-session-configured, + /// The storage backend rejected the operation. + storage(storage-error), + /// Supplied payload did not match the expected shape. + invalid-input(string), + /// Supplied `interrupt-id` does not match any live interrupt. + unknown-interrupt(string), + /// Catch-all for internal failures. + internal(string), + } + + /// An agent instance. Persistent across `generate` calls. + resource agent { + /// Construct an agent from config. + constructor(config: agent-config); + /// Start a generation. Returns a handle bound to the in-flight call. + generate: func(args: invoke-args) -> response-stream; + /// Fetch the conversation history. + get-messages: func() -> list; + /// Replace the conversation history. + set-messages: func(messages: list) -> result<_, agent-error>; + /// Fetch app state as an opaque JSON object. + get-app-state: func() -> string; + /// Replace app state. Input is an opaque JSON object. + set-app-state: func(json: string) -> result<_, agent-error>; + /// Fetch model-provider state as an opaque JSON object. + get-model-state: func() -> string; + /// Replace model-provider state. Input is an opaque JSON object. + set-model-state: func(json: string) -> result<_, agent-error>; + /// Fetch in-memory traces. Returned flat since WIT lacks recursive records; reconstruct via `parent-id`. + get-traces: func() -> list; + /// Fetch a snapshot of the current metrics totals. + get-metrics: func() -> agent-metrics; + /// Persist the current session. + save-session: func() -> result<_, agent-error>; + /// List snapshot ids for the current session. + list-snapshots: func() -> result, agent-error>; + /// Delete the current session. + delete-session: func() -> result<_, agent-error>; + } + + /// Pull-based stream of agent events; sync-WIT placeholder for `stream`. + resource event-stream { + /// Pull the next event. `none` once the stream terminates. + read: func() -> option; + } + + /// Handle to an in-flight `generate` invocation. + resource response-stream { + /// Stream of events produced during the invocation. + events: func() -> event-stream; + /// Resume a human-in-the-loop interrupt with the user's response. + respond: func(args: respond-args) -> result<_, agent-error>; + /// Cancel the invocation. Fire-and-forget. + cancel: func(); + } +} + +/// Strands agent component. Implement the imports to plug in custom tools, +/// storage, models, and other extension points; the API is ready to call. +world agent { + /// Tools the application exposes to the agent's model. + import tool-provider; + /// Receives structured log entries from the agent. + import host-log; + /// Custom snapshot storage. Selected via `session-config.storage = custom`. + import snapshot-storage; + /// Custom snapshot policy. Selected via `session-config.save-latest = trigger(id)`. + import snapshot-trigger-handler; + /// Custom model provider. Selected via `model-config.custom`. + import model-provider; + /// Conditional graph-edge callbacks for multi-agent orchestration. + import edge-handler-registry; + /// Responds to MCP elicitation. Enabled per client via `mcp-client-config.elicitation-enabled`. + import elicitation-handler; + /// Agent API your application calls. + export api; +} diff --git a/wit/conversation.wit b/wit/conversation.wit new file mode 100644 index 0000000000..0c79228e04 --- /dev/null +++ b/wit/conversation.wit @@ -0,0 +1,41 @@ +package strands:agent@0.1.0; + +/// Conversation history management. +interface conversation { + use models.{model-config}; + + /// Sliding-window strategy: trim oldest messages once the conversation + /// exceeds `window-size`. + record sliding-window-config { + /// Maximum number of messages retained. + window-size: s32, + /// Drop older tool results when trimming. + should-truncate-results: bool, + } + + /// Summarizing strategy: once the conversation grows, summarize older + /// messages into a single summary message and keep the rest verbatim. + record summarizing-config { + /// Fraction of messages to summarize. Must be between 0.1 and 0.8; + /// out-of-range values surface on the first `generate` call as + /// `agent-error::invalid-input`. + summary-ratio: f64, + /// Minimum number of recent messages preserved verbatim. + preserve-recent-messages: s32, + /// System prompt used for the summarizer model. + summarization-system-prompt: option, + /// Summarizer model. Defaults to the agent's primary model when absent. + summarization-model: option, + } + + /// Which conversation manager the agent uses. + variant conversation-manager-config { + /// No conversation management. History grows without bound and + /// context-overflow errors propagate to the caller. + none, + /// Sliding-window trimming. + sliding-window(sliding-window-config), + /// Summarization of older messages. + summarizing(summarizing-config), + } +} diff --git a/wit/deps/clocks/clocks.wit b/wit/deps/clocks/clocks.wit new file mode 100644 index 0000000000..d638f1a40f --- /dev/null +++ b/wit/deps/clocks/clocks.wit @@ -0,0 +1,157 @@ +package wasi:clocks@0.2.6; + +/// WASI Monotonic Clock is a clock API intended to let users measure elapsed +/// time. +/// +/// It is intended to be portable at least between Unix-family platforms and +/// Windows. +/// +/// A monotonic clock is a clock which has an unspecified initial value, and +/// successive reads of the clock will produce non-decreasing values. +@since(version = 0.2.0) +interface monotonic-clock { + @since(version = 0.2.0) + use wasi:io/poll@0.2.6.{pollable}; + + /// An instant in time, in nanoseconds. An instant is relative to an + /// unspecified initial value, and can only be compared to instances from + /// the same monotonic-clock. + @since(version = 0.2.0) + type instant = u64; + + /// A duration of time, in nanoseconds. + @since(version = 0.2.0) + type duration = u64; + + /// Read the current value of the clock. + /// + /// The clock is monotonic, therefore calling this function repeatedly will + /// produce a sequence of non-decreasing values. + @since(version = 0.2.0) + now: func() -> instant; + + /// Query the resolution of the clock. Returns the duration of time + /// corresponding to a clock tick. + @since(version = 0.2.0) + resolution: func() -> duration; + + /// Create a `pollable` which will resolve once the specified instant + /// has occurred. + @since(version = 0.2.0) + subscribe-instant: func(when: instant) -> pollable; + + /// Create a `pollable` that will resolve after the specified duration has + /// elapsed from the time this function is invoked. + @since(version = 0.2.0) + subscribe-duration: func(when: duration) -> pollable; +} + +/// WASI Wall Clock is a clock API intended to let users query the current +/// time. The name "wall" makes an analogy to a "clock on the wall", which +/// is not necessarily monotonic as it may be reset. +/// +/// It is intended to be portable at least between Unix-family platforms and +/// Windows. +/// +/// A wall clock is a clock which measures the date and time according to +/// some external reference. +/// +/// External references may be reset, so this clock is not necessarily +/// monotonic, making it unsuitable for measuring elapsed time. +/// +/// It is intended for reporting the current date and time for humans. +@since(version = 0.2.0) +interface wall-clock { + /// A time and date in seconds plus nanoseconds. + @since(version = 0.2.0) + record datetime { + seconds: u64, + nanoseconds: u32, + } + + /// Read the current value of the clock. + /// + /// This clock is not monotonic, therefore calling this function repeatedly + /// will not necessarily produce a sequence of non-decreasing values. + /// + /// The returned timestamps represent the number of seconds since + /// 1970-01-01T00:00:00Z, also known as [POSIX's Seconds Since the Epoch], + /// also known as [Unix Time]. + /// + /// The nanoseconds field of the output is always less than 1000000000. + /// + /// [POSIX's Seconds Since the Epoch]: https://pubs.opengroup.org/onlinepubs/9699919799/xrat/V4_xbd_chap04.html#tag_21_04_16 + /// [Unix Time]: https://en.wikipedia.org/wiki/Unix_time + @since(version = 0.2.0) + now: func() -> datetime; + + /// Query the resolution of the clock. + /// + /// The nanoseconds field of the output is always less than 1000000000. + @since(version = 0.2.0) + resolution: func() -> datetime; +} + +@unstable(feature = clocks-timezone) +interface timezone { + @unstable(feature = clocks-timezone) + use wall-clock.{datetime}; + + /// Information useful for displaying the timezone of a specific `datetime`. + /// + /// This information may vary within a single `timezone` to reflect daylight + /// saving time adjustments. + @unstable(feature = clocks-timezone) + record timezone-display { + /// The number of seconds difference between UTC time and the local + /// time of the timezone. + /// + /// The returned value will always be less than 86400 which is the + /// number of seconds in a day (24*60*60). + /// + /// In implementations that do not expose an actual time zone, this + /// should return 0. + utc-offset: s32, + /// The abbreviated name of the timezone to display to a user. The name + /// `UTC` indicates Coordinated Universal Time. Otherwise, this should + /// reference local standards for the name of the time zone. + /// + /// In implementations that do not expose an actual time zone, this + /// should be the string `UTC`. + /// + /// In time zones that do not have an applicable name, a formatted + /// representation of the UTC offset may be returned, such as `-04:00`. + name: string, + /// Whether daylight saving time is active. + /// + /// In implementations that do not expose an actual time zone, this + /// should return false. + in-daylight-saving-time: bool, + } + + /// Return information needed to display the given `datetime`. This includes + /// the UTC offset, the time zone name, and a flag indicating whether + /// daylight saving time is active. + /// + /// If the timezone cannot be determined for the given `datetime`, return a + /// `timezone-display` for `UTC` with a `utc-offset` of 0 and no daylight + /// saving time. + @unstable(feature = clocks-timezone) + display: func(when: datetime) -> timezone-display; + + /// The same as `display`, but only return the UTC offset. + @unstable(feature = clocks-timezone) + utc-offset: func(when: datetime) -> s32; +} + +@since(version = 0.2.0) +world imports { + @since(version = 0.2.0) + import wasi:io/poll@0.2.6; + @since(version = 0.2.0) + import monotonic-clock; + @since(version = 0.2.0) + import wall-clock; + @unstable(feature = clocks-timezone) + import timezone; +} diff --git a/wit/deps/io/io.wit b/wit/deps/io/io.wit new file mode 100644 index 0000000000..08ad78e6b7 --- /dev/null +++ b/wit/deps/io/io.wit @@ -0,0 +1,331 @@ +package wasi:io@0.2.6; + +@since(version = 0.2.0) +interface error { + /// A resource which represents some error information. + /// + /// The only method provided by this resource is `to-debug-string`, + /// which provides some human-readable information about the error. + /// + /// In the `wasi:io` package, this resource is returned through the + /// `wasi:io/streams/stream-error` type. + /// + /// To provide more specific error information, other interfaces may + /// offer functions to "downcast" this error into more specific types. For example, + /// errors returned from streams derived from filesystem types can be described using + /// the filesystem's own error-code type. This is done using the function + /// `wasi:filesystem/types/filesystem-error-code`, which takes a `borrow` + /// parameter and returns an `option`. + /// + /// The set of functions which can "downcast" an `error` into a more + /// concrete type is open. + @since(version = 0.2.0) + resource error { + /// Returns a string that is suitable to assist humans in debugging + /// this error. + /// + /// WARNING: The returned string should not be consumed mechanically! + /// It may change across platforms, hosts, or other implementation + /// details. Parsing this string is a major platform-compatibility + /// hazard. + @since(version = 0.2.0) + to-debug-string: func() -> string; + } +} + +/// A poll API intended to let users wait for I/O events on multiple handles +/// at once. +@since(version = 0.2.0) +interface poll { + /// `pollable` represents a single I/O event which may be ready, or not. + @since(version = 0.2.0) + resource pollable { + /// Return the readiness of a pollable. This function never blocks. + /// + /// Returns `true` when the pollable is ready, and `false` otherwise. + @since(version = 0.2.0) + ready: func() -> bool; + /// `block` returns immediately if the pollable is ready, and otherwise + /// blocks until ready. + /// + /// This function is equivalent to calling `poll.poll` on a list + /// containing only this pollable. + @since(version = 0.2.0) + block: func(); + } + + /// Poll for completion on a set of pollables. + /// + /// This function takes a list of pollables, which identify I/O sources of + /// interest, and waits until one or more of the events is ready for I/O. + /// + /// The result `list` contains one or more indices of handles in the + /// argument list that is ready for I/O. + /// + /// This function traps if either: + /// - the list is empty, or: + /// - the list contains more elements than can be indexed with a `u32` value. + /// + /// A timeout can be implemented by adding a pollable from the + /// wasi-clocks API to the list. + /// + /// This function does not return a `result`; polling in itself does not + /// do any I/O so it doesn't fail. If any of the I/O sources identified by + /// the pollables has an error, it is indicated by marking the source as + /// being ready for I/O. + @since(version = 0.2.0) + poll: func(in: list>) -> list; +} + +/// WASI I/O is an I/O abstraction API which is currently focused on providing +/// stream types. +/// +/// In the future, the component model is expected to add built-in stream types; +/// when it does, they are expected to subsume this API. +@since(version = 0.2.0) +interface streams { + @since(version = 0.2.0) + use error.{error}; + @since(version = 0.2.0) + use poll.{pollable}; + + /// An error for input-stream and output-stream operations. + @since(version = 0.2.0) + variant stream-error { + /// The last operation (a write or flush) failed before completion. + /// + /// More information is available in the `error` payload. + /// + /// After this, the stream will be closed. All future operations return + /// `stream-error::closed`. + last-operation-failed(error), + /// The stream is closed: no more input will be accepted by the + /// stream. A closed output-stream will return this error on all + /// future operations. + closed, + } + + /// An input bytestream. + /// + /// `input-stream`s are *non-blocking* to the extent practical on underlying + /// platforms. I/O operations always return promptly; if fewer bytes are + /// promptly available than requested, they return the number of bytes promptly + /// available, which could even be zero. To wait for data to be available, + /// use the `subscribe` function to obtain a `pollable` which can be polled + /// for using `wasi:io/poll`. + @since(version = 0.2.0) + resource input-stream { + /// Perform a non-blocking read from the stream. + /// + /// When the source of a `read` is binary data, the bytes from the source + /// are returned verbatim. When the source of a `read` is known to the + /// implementation to be text, bytes containing the UTF-8 encoding of the + /// text are returned. + /// + /// This function returns a list of bytes containing the read data, + /// when successful. The returned list will contain up to `len` bytes; + /// it may return fewer than requested, but not more. The list is + /// empty when no bytes are available for reading at this time. The + /// pollable given by `subscribe` will be ready when more bytes are + /// available. + /// + /// This function fails with a `stream-error` when the operation + /// encounters an error, giving `last-operation-failed`, or when the + /// stream is closed, giving `closed`. + /// + /// When the caller gives a `len` of 0, it represents a request to + /// read 0 bytes. If the stream is still open, this call should + /// succeed and return an empty list, or otherwise fail with `closed`. + /// + /// The `len` parameter is a `u64`, which could represent a list of u8 which + /// is not possible to allocate in wasm32, or not desirable to allocate as + /// as a return value by the callee. The callee may return a list of bytes + /// less than `len` in size while more bytes are available for reading. + @since(version = 0.2.0) + read: func(len: u64) -> result, stream-error>; + /// Read bytes from a stream, after blocking until at least one byte can + /// be read. Except for blocking, behavior is identical to `read`. + @since(version = 0.2.0) + blocking-read: func(len: u64) -> result, stream-error>; + /// Skip bytes from a stream. Returns number of bytes skipped. + /// + /// Behaves identical to `read`, except instead of returning a list + /// of bytes, returns the number of bytes consumed from the stream. + @since(version = 0.2.0) + skip: func(len: u64) -> result; + /// Skip bytes from a stream, after blocking until at least one byte + /// can be skipped. Except for blocking behavior, identical to `skip`. + @since(version = 0.2.0) + blocking-skip: func(len: u64) -> result; + /// Create a `pollable` which will resolve once either the specified stream + /// has bytes available to read or the other end of the stream has been + /// closed. + /// The created `pollable` is a child resource of the `input-stream`. + /// Implementations may trap if the `input-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + @since(version = 0.2.0) + subscribe: func() -> pollable; + } + + /// An output bytestream. + /// + /// `output-stream`s are *non-blocking* to the extent practical on + /// underlying platforms. Except where specified otherwise, I/O operations also + /// always return promptly, after the number of bytes that can be written + /// promptly, which could even be zero. To wait for the stream to be ready to + /// accept data, the `subscribe` function to obtain a `pollable` which can be + /// polled for using `wasi:io/poll`. + /// + /// Dropping an `output-stream` while there's still an active write in + /// progress may result in the data being lost. Before dropping the stream, + /// be sure to fully flush your writes. + @since(version = 0.2.0) + resource output-stream { + /// Check readiness for writing. This function never blocks. + /// + /// Returns the number of bytes permitted for the next call to `write`, + /// or an error. Calling `write` with more bytes than this function has + /// permitted will trap. + /// + /// When this function returns 0 bytes, the `subscribe` pollable will + /// become ready when this function will report at least 1 byte, or an + /// error. + @since(version = 0.2.0) + check-write: func() -> result; + /// Perform a write. This function never blocks. + /// + /// When the destination of a `write` is binary data, the bytes from + /// `contents` are written verbatim. When the destination of a `write` is + /// known to the implementation to be text, the bytes of `contents` are + /// transcoded from UTF-8 into the encoding of the destination and then + /// written. + /// + /// Precondition: check-write gave permit of Ok(n) and contents has a + /// length of less than or equal to n. Otherwise, this function will trap. + /// + /// returns Err(closed) without writing if the stream has closed since + /// the last call to check-write provided a permit. + @since(version = 0.2.0) + write: func(contents: list) -> result<_, stream-error>; + /// Perform a write of up to 4096 bytes, and then flush the stream. Block + /// until all of these operations are complete, or an error occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write`, and `flush`, and is implemented with the + /// following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while !contents.is_empty() { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, contents.len()); + /// let (chunk, rest) = contents.split_at(len); + /// this.write(chunk ); // eliding error handling + /// contents = rest; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + @since(version = 0.2.0) + blocking-write-and-flush: func(contents: list) -> result<_, stream-error>; + /// Request to flush buffered output. This function never blocks. + /// + /// This tells the output-stream that the caller intends any buffered + /// output to be flushed. the output which is expected to be flushed + /// is all that has been passed to `write` prior to this call. + /// + /// Upon calling this function, the `output-stream` will not accept any + /// writes (`check-write` will return `ok(0)`) until the flush has + /// completed. The `subscribe` pollable will become ready when the + /// flush has completed and the stream can accept more writes. + @since(version = 0.2.0) + flush: func() -> result<_, stream-error>; + /// Request to flush buffered output, and block until flush completes + /// and stream is ready for writing again. + @since(version = 0.2.0) + blocking-flush: func() -> result<_, stream-error>; + /// Create a `pollable` which will resolve once the output-stream + /// is ready for more writing, or an error has occurred. When this + /// pollable is ready, `check-write` will return `ok(n)` with n>0, or an + /// error. + /// + /// If the stream is closed, this pollable is always ready immediately. + /// + /// The created `pollable` is a child resource of the `output-stream`. + /// Implementations may trap if the `output-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + @since(version = 0.2.0) + subscribe: func() -> pollable; + /// Write zeroes to a stream. + /// + /// This should be used precisely like `write` with the exact same + /// preconditions (must use check-write first), but instead of + /// passing a list of bytes, you simply pass the number of zero-bytes + /// that should be written. + @since(version = 0.2.0) + write-zeroes: func(len: u64) -> result<_, stream-error>; + /// Perform a write of up to 4096 zeroes, and then flush the stream. + /// Block until all of these operations are complete, or an error + /// occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write-zeroes`, and `flush`, and is implemented with + /// the following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while num_zeroes != 0 { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, num_zeroes); + /// this.write-zeroes(len); // eliding error handling + /// num_zeroes -= len; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + @since(version = 0.2.0) + blocking-write-zeroes-and-flush: func(len: u64) -> result<_, stream-error>; + /// Read from one stream and write to another. + /// + /// The behavior of splice is equivalent to: + /// 1. calling `check-write` on the `output-stream` + /// 2. calling `read` on the `input-stream` with the smaller of the + /// `check-write` permitted length and the `len` provided to `splice` + /// 3. calling `write` on the `output-stream` with that read data. + /// + /// Any error reported by the call to `check-write`, `read`, or + /// `write` ends the splice and reports that error. + /// + /// This function returns the number of bytes transferred; it may be less + /// than `len`. + @since(version = 0.2.0) + splice: func(src: borrow, len: u64) -> result; + /// Read from one stream and write to another, with blocking. + /// + /// This is similar to `splice`, except that it blocks until the + /// `output-stream` is ready for writing, and the `input-stream` + /// is ready for reading, before performing the `splice`. + @since(version = 0.2.0) + blocking-splice: func(src: borrow, len: u64) -> result; + } +} + +@since(version = 0.2.0) +world imports { + @since(version = 0.2.0) + import error; + @since(version = 0.2.0) + import poll; + @since(version = 0.2.0) + import streams; +} diff --git a/wit/logging.wit b/wit/logging.wit new file mode 100644 index 0000000000..e711ea65d2 --- /dev/null +++ b/wit/logging.wit @@ -0,0 +1,32 @@ +package strands:agent@0.1.0; + +/// Structured logging emitted by the agent. Your application receives +/// entries via the `log` method. +interface host-log { + /// Severity level of a log entry. + enum log-level { + /// Fine-grained diagnostics, typically disabled in production. + trace, + /// Debugging detail useful during development. + debug, + /// Routine operational information. + info, + /// Recoverable issue worth surfacing. + warn, + /// Failure that may require attention. + error, + } + + /// A single structured log entry. + record log-entry { + /// Severity. + level: log-level, + /// Human-readable message. + message: string, + /// Structured context as a JSON object. + context: option, + } + + /// Emit a log entry. + log: func(entry: log-entry); +} diff --git a/wit/mcp.wit b/wit/mcp.wit new file mode 100644 index 0000000000..28d469ed71 --- /dev/null +++ b/wit/mcp.wit @@ -0,0 +1,148 @@ +package strands:agent@0.1.0; + +/// Model Context Protocol (MCP) client configuration. +interface mcp { + use wasi:clocks/monotonic-clock@0.2.6.{duration}; + use tools.{tool-spec}; + + /// Connection state of an MCP client. + enum mcp-connection-state { + /// Not connected. + disconnected, + /// Connected and ready. + connected, + /// Connection failed. + failed, + } + + /// How the client talks to the MCP server. + variant mcp-transport { + /// STDIO transport. Spawn a local process and talk via pipes. + stdio(stdio-transport-config), + /// Streamable HTTP transport, per the current MCP specification. + streamable-http(http-transport-config), + /// Legacy Server-Sent Events transport. Retained for older servers. + sse(sse-transport-config), + } + + /// STDIO transport configuration. + record stdio-transport-config { + /// Command to execute. + command: string, + /// Arguments passed to the command. + args: list, + /// Extra environment variables to set for the child process. + env: list, + /// Working directory for the child process. + cwd: option, + } + + /// Single environment variable entry. + record env-var { + /// Variable name. + key: string, + /// Variable value. + value: string, + } + + /// HTTP transport configuration. + record http-transport-config { + /// Server endpoint URL. + url: string, + /// Extra HTTP headers. + headers: list, + } + + /// SSE transport configuration. + record sse-transport-config { + /// Server endpoint URL. + url: string, + /// Extra HTTP headers. + headers: list, + } + + /// Single HTTP header entry. + record http-header { + /// Header name. + name: string, + /// Header value. + value: string, + } + + /// Task-augmented tool execution. Enables long-running tools with + /// progress tracking. Experimental in the MCP specification. + record tasks-config { + /// Time-to-live for task polling. + ttl: duration, + /// Maximum time to wait for task completion while polling. + poll-timeout: duration, + } + + /// MCP client configuration. + record mcp-client-config { + /// Stable identifier passed back on every elicitation call. One + /// application can register multiple MCP clients under distinct ids. + client-id: string, + /// Application name advertised to the MCP server. + application-name: option, + /// Application version advertised to the MCP server. + application-version: option, + /// Transport configuration. + transport: mcp-transport, + /// When set, enables task-augmented tool invocation. + tasks-config: option, + /// Whether the client advertises elicitation support. + elicitation-enabled: bool, + /// Whether connection failures log a warning instead of throwing. + fail-open: bool, + /// Disable OpenTelemetry MCP instrumentation. + disable-instrumentation: bool, + } +} + +/// Pluggable elicitation handler for MCP servers requesting user input. +/// Enabled per client via `mcp-client-config.elicitation-enabled`. +interface elicitation-handler { + /// Request for user input. + record elicit-request { + /// Which MCP client is making the request. + client-id: string, + /// Human-readable prompt to show the user. + message: string, + /// Request schema as a JSON value. Either a JSON Schema describing a + /// form, or a URL-mode payload. + request: string, + } + + /// Outcome of an elicitation request. + enum elicit-action { + /// User accepted and provided content. + accept, + /// User declined. + decline, + /// User cancelled (e.g. by closing the prompt). + cancel, + } + + /// Response to an elicitation request. + record elicit-response { + /// User's decision. + action: elicit-action, + /// Content matching the requested schema, as a JSON value. Present + /// only when `action` is `accept`. + content: option, + } + + /// Why an elicitation call failed. + variant elicitation-error { + /// No handler registered for the given `client-id`. + unknown-client(string), + /// Handler raised an exception. + handler-failed(string), + /// Request timed out waiting for a human response. + timed-out, + } + + /// Prompt the user for input and return their response. + elicit: func(request: elicit-request) -> result; +} diff --git a/wit/messages.wit b/wit/messages.wit new file mode 100644 index 0000000000..9af1a18427 --- /dev/null +++ b/wit/messages.wit @@ -0,0 +1,357 @@ +package strands:agent@0.1.0; + +/// Content blocks that make up a message. +interface messages { + /// Plain text. + record text-block { + /// Text content. + text: string, + } + + /// Object stored in Amazon S3. + record s3-location { + /// URI in `s3://bucket/key` form. + uri: string, + /// Owning AWS account, for cross-account access. + bucket-owner: option, + } + + /// Source of image bytes. + variant image-source { + /// Raw image bytes. + bytes(list), + /// Publicly-accessible URL. + url(string), + /// Object in S3. + s3(s3-location), + } + + /// Image attached to a message. + record image-block { + /// Provider-accepted format, e.g. `png`, `jpg`, `jpeg`, `gif`, `webp`. + format: string, + /// Where the image bytes come from. + source: image-source, + } + + /// Source of video bytes. + variant video-source { + /// Raw video bytes. + bytes(list), + /// Object in S3. + s3(s3-location), + } + + /// Video attached to a message. + record video-block { + /// Provider-accepted format, e.g. `mp4`, `mov`, `webm`, `3gp`. + format: string, + /// Where the video bytes come from. + source: video-source, + } + + /// Source of document bytes. + variant document-source { + /// Raw document bytes. + bytes(list), + /// Plain text content. + text(string), + /// Structured content made of text blocks. + content(list), + /// Object in S3. + s3(s3-location), + } + + /// Citation configuration attached to a document. + record document-citations-config { + /// Whether the model should cite spans from this document. + enabled: bool, + } + + /// Document attached to a message. + record document-block { + /// Display name shown to the model. + name: string, + /// Provider-accepted format, e.g. `pdf`, `csv`, `docx`, `md`, `json`. + format: string, + /// Where the document bytes come from. + source: document-source, + /// Citation configuration. Absent means citations are disabled. + citations: option, + /// Additional context to prepend to the document. + context: option, + } + + /// Model's thought process. Either plain reasoning (with an optional + /// signature) or an opaque redacted blob. + record reasoning-block { + /// Reasoning text. + text: option, + /// Cryptographic signature for verification. + signature: option, + /// Opaque redacted reasoning, when the provider withheld the plain form. + redacted-content: option>, + } + + /// Prompt-caching kind. More arms will be added as providers surface + /// additional cache tiers (e.g. Anthropic's `ephemeral`). + enum cache-kind { + /// Standard provider-default caching. + default-cache, + } + + /// Marks a caching boundary in the prompt. + record cache-point-block { + /// Cache kind. + kind: cache-kind, + } + + /// How a piece of guard content should be evaluated. + enum guard-qualifier { + /// Content is a reference source the model should ground its answer on. + grounding-source, + /// Content is the user's query. + query, + /// Content is subject to guardrail policy evaluation. + guard-content, + } + + /// Text submitted to a guardrail for evaluation. + record guard-content-text { + /// How the text should be evaluated. + qualifiers: list, + /// Text content. + text: string, + } + + /// Image submitted to a guardrail for evaluation. + record guard-content-image { + /// `png` or `jpeg`. + format: string, + /// Raw image bytes. + bytes: list, + } + + /// Content submitted to a guardrail for evaluation. + variant guard-content-block { + /// Text guard content. + text(guard-content-text), + /// Image guard content. + image(guard-content-image), + } + + /// Range within a source document (characters, pages, or chunks). + record document-range { + /// Index of the source document in the input list. + document-index: s32, + /// Inclusive start offset. + start: s32, + /// Exclusive end offset. + end: s32, + } + + /// Range within a search result. + record search-result-range { + /// Index of the search result in the input list. + search-result-index: s32, + /// Inclusive start offset. + start: s32, + /// Exclusive end offset. + end: s32, + } + + /// Web citation target. + record web-location { + /// Cited URL. + url: string, + /// Domain of the cited URL, if the provider surfaces it separately. + domain: option, + } + + /// Anchor a citation points to. + variant citation-location { + /// Character range within a document. + document-char(document-range), + /// Page range within a document. + document-page(document-range), + /// Chunk range within a document. + document-chunk(document-range), + /// Range within a search result. + search-result(search-result-range), + /// Web page. + web(web-location), + } + + /// Text fragment from a source or a generated answer. + record citation-text { + /// Text content. + text: string, + } + + /// Link from generated content back to a source location. + record citation { + /// Where the citation points. + location: citation-location, + /// Opaque source identifier. + source: string, + /// Excerpts from the source. + source-content: list, + /// Display title of the source. + title: string, + } + + /// Citations emitted by the model when citations are enabled. + record citations-block { + /// Citations linking generated text to sources. + citations: list, + /// Generated text that the citations annotate. + content: list, + } + + /// Model's request to call a tool. + record tool-use-block { + /// Tool to invoke. + name: string, + /// Identifier correlating this call with its result. + tool-use-id: string, + /// Arguments as a JSON value; shape is tool-specific. + input: string, + /// Reasoning signature from thinking models. Round-trip back to the model. + reasoning-signature: option, + } + + /// Whether a tool invocation succeeded. Richer classification lives on `tools.tool-error`. + enum tool-result-status { + /// Tool completed successfully. + success, + /// Tool returned an error result. + error, + } + + /// Block valid inside `tool-result-block.content`. Narrower than `content-block`. + variant tool-result-content { + /// Text output. + text(text-block), + /// Structured JSON output. + json(json-block), + /// Image output. + image(image-block), + /// Video output. + video(video-block), + /// Document output. + document(document-block), + } + + /// Outcome of a tool execution. + record tool-result-block { + /// Matching tool-use-id from the originating call. + tool-use-id: string, + /// Whether the call succeeded. + status: tool-result-status, + /// Content emitted by the tool. + content: list, + } + + /// Structured JSON payload. Used for tool results and agent-as-tool + /// outputs that carry schema-validated data, not prose. + record json-block { + /// JSON value. + json: string, + } + + /// User response to a previously-raised interrupt. Supplied on the + /// next invocation to resume the paused agent. + record interrupt-response-block { + /// Id of the interrupt being responded to. + interrupt-id: string, + /// User's response as a JSON value. + response: string, + } + + /// Any block that can appear inside a message. + variant content-block { + /// Plain text. + text(text-block), + /// Structured JSON payload. + json(json-block), + /// Model requested a tool call. + tool-use(tool-use-block), + /// Tool call completed. + tool-result(tool-result-block), + /// Model reasoning. + reasoning(reasoning-block), + /// Caching boundary marker. + cache-point(cache-point-block), + /// Content submitted for guardrail evaluation. + guard-content(guard-content-block), + /// Image. + image(image-block), + /// Video. + video(video-block), + /// Document. + document(document-block), + /// Citations emitted by the model. + citations(citations-block), + /// Response to a prior interrupt, supplied when resuming. + interrupt-response(interrupt-response-block), + } + + /// Who a message is from. + enum role { + /// Human input. + user, + /// Model response. + assistant, + } + + /// Token consumption for a model invocation. + record usage { + /// Tokens sent to the model. + input-tokens: s32, + /// Tokens generated by the model. + output-tokens: s32, + /// Convenience sum of input and output tokens. + total-tokens: s32, + /// Input tokens served from the provider's cache. + cache-read-input-tokens: option, + /// Input tokens written to the provider's cache. + cache-write-input-tokens: option, + } + + /// Performance metrics for a model invocation. + record metrics { + /// Wall-clock latency in milliseconds. + latency-ms: f64, + } + + /// Metadata attached to a message. Not sent to model providers; persisted + /// alongside the message for bookkeeping. + record message-metadata { + /// Token usage for this message. + usage: option, + /// Performance metrics for this message. + metrics: option, + /// Arbitrary application-level metadata as an opaque JSON object. + /// The agent does not interpret this. + custom: option, + } + + /// A complete message in a conversation. + record message { + /// Speaker. + role: role, + /// Ordered content blocks making up the message. + content: list, + /// Optional bookkeeping data. + metadata: option, + } + + /// A prompt-style input: either prose or structured content. Used for + /// both system prompts and user input. + variant prompt-input { + /// Plain text prompt. + text(string), + /// Structured content blocks. + blocks(list), + } +} diff --git a/wit/models.wit b/wit/models.wit new file mode 100644 index 0000000000..9a6be6e83c --- /dev/null +++ b/wit/models.wit @@ -0,0 +1,165 @@ +package strands:agent@0.1.0; + +/// Model provider configuration and pluggable custom providers. +interface models { + /// Anthropic API model configuration. + record anthropic-config { + /// Model identifier, e.g. `claude-opus-4-7`. + model-id: option, + /// API key. Falls back to the `ANTHROPIC_API_KEY` environment variable. + api-key: option, + /// Provider-specific overrides as a JSON object. + additional-config: option, + } + + /// AWS Bedrock model configuration. + record bedrock-config { + /// Bedrock model identifier, e.g. `us.anthropic.claude-opus-4-7-v1:0`. + model-id: string, + /// AWS region. Falls back to the default credential chain. + region: option, + /// Explicit AWS access key id. Falls back to the credential chain. + access-key-id: option, + /// Explicit AWS secret access key. Falls back to the credential chain. + secret-access-key: option, + /// Explicit AWS session token, for temporary credentials. + session-token: option, + /// Provider-specific overrides as a JSON object. + additional-config: option, + } + + /// OpenAI API model configuration. + record openai-config { + /// Model identifier, e.g. `gpt-4o`. + model-id: option, + /// API key. Falls back to the `OPENAI_API_KEY` environment variable. + api-key: option, + /// Provider-specific overrides as a JSON object. + additional-config: option, + } + + /// Google Gemini API model configuration. + record gemini-config { + /// Model identifier, e.g. `gemini-2.0-flash`. + model-id: option, + /// API key. Falls back to the `GOOGLE_API_KEY` environment variable. + api-key: option, + /// Provider-specific overrides as a JSON object. + additional-config: option, + } + + /// Custom model provider supplied by your application. + record custom-model-config { + /// Identifier routed back on each call. One application can register + /// multiple providers under distinct ids. + provider-id: string, + /// Model identifier passed through to your implementation. + model-id: option, + /// Implementation-specific overrides as a JSON object. + additional-config: option, + /// Stateful providers see only the latest message; local history clears each call. + stateful: bool, + } + + /// Which model provider the agent should use. + variant model-config { + /// Anthropic API. + anthropic(anthropic-config), + /// AWS Bedrock. + bedrock(bedrock-config), + /// OpenAI API. + openai(openai-config), + /// Google Gemini API. + gemini(gemini-config), + /// Custom provider supplied by your application. Implement the + /// `model-provider` interface to serve it. + custom(custom-model-config), + } + + /// Sampling parameters applied to every call on the chosen provider. + record model-params { + /// Upper bound on generated tokens per response. + max-tokens: option, + /// Sampling temperature. + temperature: option, + /// Nucleus sampling probability mass. + top-p: option, + } + + /// Why a model call failed. Retry logic keys off of which arm fires, so + /// implementations should pick the narrowest one that fits. + variant model-error { + /// No provider registered for the given `provider-id`. + unknown-provider(string), + /// Provider refused the request due to malformed input. + invalid-request(string), + /// Caller lacks permission (missing or expired credentials). + unauthorized(string), + /// Provider returned a rate-limit error. Retry after a backoff. + throttled(string), + /// Provider returned a server-side error. Retry may succeed. + server-error(string), + /// Request exceeded the model's context window. + context-window-exceeded, + /// Content was rejected by provider safety policy. + content-filtered(string), + /// Transient network or transport failure. Retry may succeed. + transient(string), + /// Catch-all for internal failures. + internal(string), + } +} + +/// Pluggable model provider. Selected via `model-config.custom`. +/// `provider-id` is passed on every call so one impl can serve many providers. +interface model-provider { + use models.{model-error}; + use messages.{message, prompt-input}; + use tools.{tool-spec, tool-choice}; + use streaming.{stream-event}; + + /// Pull-based stream of model events from a custom provider; host produces, guest reads. + resource model-event-stream { + /// Pull the next event. `none` once the stream terminates. + read: func() -> option; + } + + /// Options passed alongside the messages on each streaming call. + record model-stream-options { + /// System prompt. Either plain text or structured content blocks. + system-prompt: option, + /// Tools advertised to the model for this call. + tools: option>, + /// Tool choice policy. + tool-choice: option, + } + + /// Arguments for `start-stream`. + record start-stream-args { + /// Which custom provider instance is being called. + provider-id: string, + /// Conversation history. + messages: list, + /// Call-time options. + options: model-stream-options, + } + + /// Arguments for `count-tokens`. + record count-tokens-args { + /// Which custom provider instance is being called. + provider-id: string, + /// Messages to estimate. + messages: list, + /// Optional system prompt influencing the count. + system-prompt: option, + /// Optional tool specs influencing the count. + tools: option>, + } + + /// Start a streaming generation. Return an async stream; cancellation + /// is signalled by the caller dropping the reader. + start-stream: func(args: start-stream-args) -> model-event-stream; + + /// Count tokens for the given input, for proactive context management. + count-tokens: func(args: count-tokens-args) -> result; +} diff --git a/wit/multiagent.wit b/wit/multiagent.wit new file mode 100644 index 0000000000..84383e00c1 --- /dev/null +++ b/wit/multiagent.wit @@ -0,0 +1,250 @@ +package strands:agent@0.1.0; + +/// Multi-agent orchestration: Graph (dependency DAG) and Swarm (handoff). +interface multi-agent { + use wasi:clocks/monotonic-clock@0.2.6.{duration}; + use messages.{content-block, prompt-input, usage, metrics}; + use streaming.{stream-event}; + + /// Lifecycle status of a node or overall run. + enum orchestration-status { + /// Not started. + pending, + /// Running. + executing, + /// Finished successfully. + completed, + /// Finished with an error. + failed, + /// Cancelled before or during processing. + cancelled, + } + + /// Terminal status of a node or run. + enum terminal-status { + /// Finished successfully. + completed, + /// Finished with an error. + failed, + /// Cancelled before or during processing. + cancelled, + } + + /// What a node is. + enum node-kind { + /// Wraps a single Agent. + agent, + /// Wraps a nested multi-agent orchestrator. + multi-agent, + } + + /// Definition of an agent-backed node. + record agent-node-config { + /// Node identifier, unique within its graph/swarm. + id: string, + /// Human-readable description. + description: option, + /// Per-node wall-clock ceiling. Falls back to the enclosing + /// orchestrator's `node-timeout`. + timeout: option, + /// Agent configuration as a JSON value matching `api.agent-config`. + agent-config: string, + } + + /// Definition of a node that wraps another orchestrator. + record multi-agent-node-config { + /// Node identifier, unique within its parent graph/swarm. + id: string, + /// Human-readable description. + description: option, + /// Nested orchestrator as a JSON value matching `graph-config` or + /// `swarm-config`. + orchestrator: string, + } + + /// Any node a graph or swarm can execute. + variant node-config { + /// Wraps a single agent. + agent(agent-node-config), + /// Wraps a nested orchestrator. + multi-agent(multi-agent-node-config), + } + + /// Condition attached to a graph edge. + record edge-handler { + /// Handler identifier. Resolved against the callbacks registered via + /// `edge-handler-registry`. + handler-id: string, + } + + /// Edge connecting two graph nodes. + record edge-config { + /// Source node id. + source: string, + /// Target node id. + target: string, + /// Handler controlling whether the edge fires. Absent means always. + handler: option, + } + + /// Runtime configuration for a Graph. + record graph-config { + /// Identifier of this graph. + id: string, + /// Nodes making up the graph. + nodes: list, + /// Edges connecting the nodes. + edges: list, + /// Explicit source nodes. Empty means auto-detect (nodes with no incoming edges). + sources: list, + /// Max nodes running in parallel. Absent means no limit. + max-concurrency: option, + /// Max total node executions. Absent means no limit. + max-steps: option, + /// Wall-clock ceiling for the whole graph. Absent means no limit. + timeout: option, + /// Fallback per-node wall-clock ceiling. Absent means no limit. + node-timeout: option, + } + + /// Runtime configuration for a Swarm. + record swarm-config { + /// Identifier of this swarm. + id: string, + /// Agent-backed nodes available for handoff. + nodes: list, + /// Agent that runs first. + start-node-id: string, + /// Max total agent executions. Absent means no limit. + max-steps: option, + /// Wall-clock ceiling for the whole swarm. Absent means no limit. + timeout: option, + /// Fallback per-node wall-clock ceiling. Absent means no limit. + node-timeout: option, + } + + /// Why a node or run ended in `failed` status. + variant node-error { + /// An underlying agent or nested orchestrator failed. + execution(string), + /// Wall-clock ceiling was exceeded. + timeout, + /// A declared runtime limit (max-steps, max-concurrency) was hit. + limit-exceeded(string), + /// Edge handler rejected the traversal with an error. + edge-handler(string), + /// Invalid configuration detected at run time. + invalid-config(string), + /// Catch-all for internal failures. + internal(string), + } + + /// Result of a single node execution. + record node-result { + /// Node identifier. + node-id: string, + /// Terminal status. + status: terminal-status, + /// Wall-clock duration. + duration: duration, + /// Content produced by the node, in order. + content: list, + /// Error payload when `status` is `failed`. + error: option, + /// Validated structured output as a JSON value. Present when a schema + /// was supplied. + structured-output: option, + /// Token usage for the node. + usage: option, + /// Performance metrics for the node. + metrics: option, + } + + /// Final result of a graph or swarm run. + record multi-agent-result { + /// Overall status. + status: terminal-status, + /// Per-node results, in execution order. + nodes: list, + /// Total elapsed wall-clock time. + duration: duration, + /// Summed token usage across all nodes, including partial usage + /// from failed or cancelled nodes where the provider reports it. + usage: option, + /// Summed performance metrics across all nodes, including partial + /// metrics from failed or cancelled nodes. + metrics: option, + } + + /// Arguments for invoking a graph or swarm. + record multi-agent-invoke-args { + /// Task input. + input: prompt-input, + /// Invocation-scoped state bag as an opaque JSON object. + invocation-state: option, + } + + /// Payload for `node-start`. + record node-start-data { + /// Node identifier. + node-id: string, + /// Whether this node wraps a single agent or a nested orchestrator. + kind: node-kind, + } + + /// Payload for `node-event`. Carries a nested stream event from a + /// running node. + record node-event-data { + /// Node that produced the event. + node-id: string, + /// Inner event emitted by the node's invocation. + event: stream-event, + } + + /// Events emitted while streaming a multi-agent run. + variant multi-agent-stream-event { + /// A node began executing. + node-start(node-start-data), + /// A nested stream event from a running node. + nested(node-event-data), + /// A node finished executing. + node-stop(node-result), + /// A handoff happened between nodes. + handoff(handoff-event), + /// Terminal result for the run. + run-complete(multi-agent-result), + } + + /// Payload for a handoff edge firing. + record handoff-event { + /// Nodes the run moves from (usually one). + from-node-ids: list, + /// Nodes the run moves to. + to-node-ids: list, + } +} + +/// Pluggable edge-handler registry for graph edges that need custom routing. +interface edge-handler-registry { + use multi-agent.{node-result}; + + /// Why an edge evaluation failed. + variant edge-handler-error { + /// No handler registered for the given id. + unknown(string), + /// Handler raised an exception. + failed(string), + } + + /// State snapshot passed to `evaluate` so the handler can branch on + /// prior node results. + record handler-state { + /// Results accumulated so far, keyed by node id. + results: list, + /// Total node executions completed. + execution-count: s32, + } + + /// Decide whether an edge should be traversed. + evaluate: func(handler-id: string, state: handler-state) -> result; +} diff --git a/wit/retry.wit b/wit/retry.wit new file mode 100644 index 0000000000..aac70815cd --- /dev/null +++ b/wit/retry.wit @@ -0,0 +1,76 @@ +package strands:agent@0.1.0; + +/// Retry policy for failed model and tool calls. +interface retry { + use wasi:clocks/monotonic-clock@0.2.6.{duration}; + + /// How much random variation to apply to computed delays. + enum jitter-kind { + /// No jitter applied. + none, + /// Uniform random in `[0, delay]`. + full, + /// Uniform random in `[delay/2, delay]`. + equal, + /// Decorrelated exponential jitter (AWS-style). + decorrelated, + } + + /// Fixed delay between attempts. + record constant-backoff-config { + /// Delay returned for every retry. + delay: duration, + } + + /// Delay grows linearly with attempt number. + record linear-backoff-config { + /// Base delay. Delay on attempt N is `base * N`. + base: duration, + /// Upper bound applied before jitter. + max: duration, + /// Jitter mode. + jitter: jitter-kind, + } + + /// Delay grows exponentially with attempt number. + record exponential-backoff-config { + /// Base delay on the first retry. + base: duration, + /// Upper bound applied before jitter. + max: duration, + /// Growth factor. Delay on attempt N is `base * factor^(N-1)`. + factor: f64, + /// Jitter mode. + jitter: jitter-kind, + } + + /// Backoff curve applied between attempts. + variant backoff-strategy { + /// Fixed delay. + constant(constant-backoff-config), + /// Linear growth. + linear(linear-backoff-config), + /// Exponential growth. + exponential(exponential-backoff-config), + } + + /// A single retry strategy. Default is exponential backoff, full jitter, 6 attempts. + record model-retry-strategy { + /// Maximum number of attempts, including the initial call. + max-attempts: s32, + /// Backoff curve applied between attempts. + backoff: backoff-strategy, + /// Upper bound on total retry window. Absent means no cap. + total-budget: option, + } + + /// Retry configuration attached to an agent. + /// Every strategy observes every failure; the first to request a delay wins. + /// Empty list disables retries; omitting `agent-config.retry` applies a default + /// single exponential strategy. Two strategies with the same `backoff` arm + /// surface as `agent-error::invalid-input`. + record retry-config { + /// Strategies evaluated on every retryable failure. + strategies: list, + } +} diff --git a/wit/sessions.wit b/wit/sessions.wit new file mode 100644 index 0000000000..921f6115fc --- /dev/null +++ b/wit/sessions.wit @@ -0,0 +1,298 @@ +package strands:agent@0.1.0; + +/// Session persistence configuration and pluggable storage backends. +interface sessions { + use wasi:clocks/wall-clock@0.2.6.{datetime}; + use messages.{message}; + + /// Local filesystem snapshot storage. + record file-storage-config { + /// Directory under which snapshots are written. + base-dir: string, + } + + /// S3 snapshot storage. + record s3-storage-config { + /// Target bucket. + bucket: string, + /// AWS region. Falls back to the default credential chain. + region: option, + /// Key prefix under which snapshots are stored. + prefix: option, + } + + /// Reference to an application-implemented storage backend. + record custom-storage-config { + /// Identifier routed back to the `snapshot-storage` handler on every + /// call. One application can register multiple backends under + /// distinct ids. + backend-id: string, + } + + /// Where to persist session snapshots. + variant storage-config { + /// Local filesystem. + file(file-storage-config), + /// Amazon S3. + s3(s3-storage-config), + /// Application-implemented backend. + custom(custom-storage-config), + } + + /// When to update the "latest" snapshot pointer. The `trigger` arm + /// carries the id of an application-supplied callback that decides + /// per-invocation. + variant save-latest-policy { + /// After every message added to the conversation. + message, + /// Once per invocation, after it completes. + invocation, + /// Each invocation consults the named `snapshot-trigger-handler`. + /// The id identifies which handler to invoke. + trigger(string), + } + + /// Session persistence configuration attached to an agent. + record session-config { + /// Identifier for this session's snapshots. + session-id: string, + /// Storage backend. + storage: storage-config, + /// When to update the "latest" snapshot. Absent uses the `invocation` + /// default. + save-latest: option, + } + + /// Which kind of state a snapshot describes. + enum snapshot-scope { + /// Single-agent state. + agent, + /// Multi-agent orchestrator state. + multi-agent, + } + + /// Locator for a snapshot within the storage hierarchy. + record snapshot-location { + /// Session identifier. + session-id: string, + /// What kind of state this snapshot holds. + scope: snapshot-scope, + /// Scope-specific identifier (agent id or multi-agent id). + scope-id: string, + } + + /// Sliding-window conversation manager state at snapshot time. + record sliding-window-state { + /// Number of messages dropped from the front of history this lifetime. + removed-message-count: s32, + } + + /// Summarizing conversation manager state at snapshot time. + record summarizing-state { + /// Current summary message, carrying the accumulated summary text. + /// Absent before the first summarization runs. + summary-message: option, + /// Number of messages removed via summarization. + removed-message-count: s32, + } + + /// Conversation manager snapshot state. Which arm is populated depends + /// on the conversation manager the agent was built with. + variant conversation-manager-state { + /// No conversation manager or null manager; nothing to persist. + none, + /// Sliding-window manager state. + sliding-window(sliding-window-state), + /// Summarizing manager state. + summarizing(summarizing-state), + } + + /// Retry-strategy state at snapshot time. + record retry-strategy-state { + /// Attempts used against the current model call's budget. + attempts-used: s32, + /// Milliseconds elapsed against the strategy's total budget. + elapsed-ms: s64, + } + + /// Named plugin state. `data` is an opaque JSON object owned by the plugin. + record plugin-state-entry { + /// Plugin identifier, matches the plugin's `name`. + plugin-name: string, + /// Plugin-owned state as an opaque JSON object. + data: string, + } + + /// Framework-owned snapshot state. All fields are optional because an + /// agent may not exercise every subsystem in a given run. + record snapshot-data { + /// Conversation history at snapshot time. + messages: list, + /// Conversation manager state. + conversation-manager: option, + /// Retry strategy state. + retry-strategy: option, + /// Model-provider state (e.g. server-side session id for stateful + /// providers) as an opaque JSON object. + model-state: option, + /// Per-plugin state for plugins the framework doesn't model directly. + plugins: list, + } + + /// Point-in-time capture of agent or orchestrator state. + record snapshot { + /// Which scope this snapshot belongs to. + scope: snapshot-scope, + /// Schema version string for forward compatibility. + schema-version: string, + /// Wall-clock time the snapshot was created. + created-at: datetime, + /// Framework-owned state. + data: snapshot-data, + /// Application-owned data as an opaque JSON object. The agent does + /// not read or modify this. + app-data: string, + } + + /// Metadata describing the snapshot manifest file. + record snapshot-manifest { + /// Schema version of the manifest. + schema-version: string, + /// Wall-clock time of the most recent manifest update. + updated-at: datetime, + } + + /// Why a snapshot operation failed. + variant storage-error { + /// No snapshot or manifest at the requested location. + not-found, + /// Caller lacks permission to read or write the storage. + access-denied(string), + /// Backing storage is full or over quota. + out-of-space, + /// Snapshot is malformed or cannot be deserialized. + corrupt(string), + /// Concurrent writers collided; retrying may succeed. + conflict(string), + /// Transient I/O failure; retrying may succeed. + transient(string), + /// Permanent backend failure. + permanent(string), + /// No custom backend registered for the given backend-id. + unknown-backend(string), + } +} + +/// Pluggable snapshot storage. Selected via `storage-config = custom(...)`. +/// `backend-id` is passed on every call so one impl can serve many backends. +interface snapshot-storage { + use sessions.{snapshot, snapshot-location, snapshot-manifest, storage-error}; + + /// Arguments for `save-snapshot`. + record save-snapshot-args { + /// Backend this call targets. + backend-id: string, + /// Where to write the snapshot. + location: snapshot-location, + /// Snapshot identifier. Agent-assigned UUID v7; treat as opaque. + snapshot-id: string, + /// Whether this snapshot should become the new "latest" pointer. + is-latest: bool, + /// Snapshot data. + snapshot: snapshot, + } + + /// Arguments for `load-snapshot`. + record load-snapshot-args { + /// Backend this call targets. + backend-id: string, + /// Which session/scope to load from. + location: snapshot-location, + /// Specific snapshot id. Absent loads the "latest" snapshot. + snapshot-id: option, + } + + /// Arguments for `list-snapshot-ids`. + record list-snapshot-ids-args { + /// Backend this call targets. + backend-id: string, + /// Which session/scope to list. + location: snapshot-location, + /// Cap on returned ids. + limit: option, + /// Exclusive cursor. Pass the last id returned by the previous page. + start-after: option, + } + + /// Arguments for `delete-session`. + record delete-session-args { + /// Backend this call targets. + backend-id: string, + /// Session to delete. + session-id: string, + } + + /// Arguments for `load-manifest` / `save-manifest`. + record manifest-args { + /// Backend this call targets. + backend-id: string, + /// Which session/scope's manifest to address. + location: snapshot-location, + } + + /// Arguments for `save-manifest`. + record save-manifest-args { + /// Backend this call targets. + backend-id: string, + /// Which session/scope's manifest to update. + location: snapshot-location, + /// New manifest. + manifest: snapshot-manifest, + } + + /// Persist a snapshot. + save-snapshot: func(args: save-snapshot-args) -> result<_, storage-error>; + + /// Load a snapshot. Returns `ok(none)` when the location is empty. + load-snapshot: func(args: load-snapshot-args) -> result, storage-error>; + + /// List snapshot ids for a session scope, chronologically. + list-snapshot-ids: func(args: list-snapshot-ids-args) -> result, storage-error>; + + /// Delete every snapshot for the given session id. + delete-session: func(args: delete-session-args) -> result<_, storage-error>; + + /// Load the manifest for a session scope. + load-manifest: func(args: manifest-args) -> result; + + /// Save the manifest for a session scope. + save-manifest: func(args: save-manifest-args) -> result<_, storage-error>; +} + +/// Pluggable snapshot trigger called after each invocation. +/// Enabled via `session-config.save-latest = trigger(id)`. +interface snapshot-trigger-handler { + use messages.{message}; + + /// Context passed to the trigger on each call. + record trigger-params { + /// Identifier of the trigger, matching the id supplied in + /// `save-latest-policy.trigger`. + trigger-id: string, + /// Total messages in the agent's conversation history. + message-count: s32, + /// Most recent message. Absent when history is empty. + last-message: option, + } + + /// Why a trigger evaluation failed. + variant trigger-error { + /// No trigger registered for the given id. + unknown(string), + /// Trigger raised an exception. + failed(string), + } + + /// Return true to write a new snapshot, false to skip. + should-snapshot: func(params: trigger-params) -> result; +} diff --git a/wit/streaming.wit b/wit/streaming.wit new file mode 100644 index 0000000000..7443558473 --- /dev/null +++ b/wit/streaming.wit @@ -0,0 +1,368 @@ +package strands:agent@0.1.0; + +/// Events emitted by the agent during a `generate` call. +interface streaming { + use messages.{content-block, tool-use-block, tool-result-block, message, usage, metrics}; + use models.{model-error}; + use tools.{tool-error}; + + /// Human-in-the-loop interrupt raised by a tool or hook. + record interrupt { + /// Unique identifier. Passed back as `respond-args.interrupt-id` + /// when resuming. + id: string, + /// User-defined name for the interrupt. + name: string, + /// Reason as an opaque JSON value. Absent when no reason was set. + reason: option, + } + + /// Why the model stopped generating. + enum stop-reason { + /// Natural end of the model's turn. + end-turn, + /// Model paused to call a tool. + tool-use, + /// Hit the configured token limit. + max-tokens, + /// Provider returned an error. + error, + /// Content was filtered by provider safety policy. + content-filtered, + /// A guardrail policy intervened. + guardrail-intervened, + /// A configured stop sequence was encountered. + stop-sequence, + /// Input exceeded the model's context window. + model-context-window-exceeded, + /// The caller cancelled the invocation. + cancelled, + } + + /// Usage and metrics accumulated so far. + record metadata-event { + /// Cumulative token usage. + usage: option, + /// Cumulative performance metrics. + metrics: option, + } + + /// Single key-value pair attached to a trace. Values are string-typed + /// to keep traces compact; structured payloads belong on `message`. + record trace-metadata-entry { + /// Metadata key. + key: string, + /// Metadata value. + value: string, + } + + /// In-memory trace node. Returned flat; reconstruct the tree via `parent-id`. + record agent-trace { + /// Unique identifier. + id: string, + /// Human-readable display name, e.g. `Cycle 1`, `Tool: calc`. + name: string, + /// Parent trace id. Absent for root traces. + parent-id: option, + /// Start time, milliseconds since epoch. + start-time-ms: s64, + /// End time, milliseconds since epoch. Absent while in progress. + end-time-ms: option, + /// Duration in milliseconds (`end-time - start-time`). + duration-ms: s64, + /// Metadata attached to this trace. + metadata: list, + /// Message associated with this trace. Absent when not applicable. + message: option, + } + + /// Per-tool execution metrics keyed by tool name in `agent-metrics`. + record tool-metrics { + /// Tool name. + tool-name: string, + /// Total calls. + call-count: s32, + /// Successful calls. + success-count: s32, + /// Failed calls. + error-count: s32, + /// Total execution time across all calls, in milliseconds. + total-time-ms: s64, + } + + /// Per-invocation metrics. Cycles are flattened into `agent-metrics.cycles` + /// and linked back via `invocation-id`. + record invocation-metrics { + /// Unique identifier for this invocation. + invocation-id: string, + /// Accumulated token usage for this invocation. + usage: usage, + } + + /// Per-cycle usage tracking. + record agent-loop-metrics { + /// Unique identifier for this cycle. + cycle-id: string, + /// Invocation this cycle belongs to. + invocation-id: string, + /// Duration of this cycle in milliseconds. + duration-ms: s64, + /// Token usage for this cycle. + usage: usage, + } + + /// Snapshot of agent metrics. Returned by `agent.get-metrics`. + record agent-metrics { + /// Total cycle count across the agent's lifetime. + cycle-count: s32, + /// Accumulated token usage. + accumulated-usage: usage, + /// Accumulated performance metrics. + accumulated-metrics: metrics, + /// Per-invocation totals. + invocations: list, + /// Per-cycle metrics across every invocation. Link back to an + /// invocation via `cycle.invocation-id`. + cycles: list, + /// Per-tool metrics. + tool-metrics: list, + /// Current context window utilization, measured as the input token + /// count from the most recent model call. + latest-context-size: option, + /// Projected context size for the next call. + projected-context-size: option, + } + + /// Mutable tool-use descriptor. `before-tool-call` hooks may rewrite fields. + record tool-use-data { + /// Tool to invoke. + name: string, + /// Identifier correlating this call with its result. + tool-use-id: string, + /// Arguments as a JSON value. + input: string, + } + + /// Redaction information when guardrails block content. + record hook-redaction { + /// Replacement text for the redacted user message; the original is in history. + user-message: string, + } + + /// Model response surfaced on `after-model-call`. + record model-stop-data { + /// Message returned by the model. + message: message, + /// Why the model stopped generating. + stop-reason: stop-reason, + /// Redaction info when guardrails blocked input. Absent if no redaction. + redaction: option, + } + + /// Payload for `before-invocation`. + record before-invocation-data { + /// Invocation-scoped state bag as a JSON object. Always present; + /// an empty object signals no caller-supplied state. + invocation-state: string, + } + + /// Payload for `after-invocation`. + record after-invocation-data { + /// Invocation-scoped state bag as a JSON object. Always present; + /// an empty object signals no caller-supplied state. + invocation-state: string, + } + + /// Payload for `message-added`. + record message-added-data { + /// Message appended to the conversation. + message: message, + } + + /// Payload for `before-model-call`. + record before-model-call-data { + /// Projected input token count for the upcoming call. Absent when + /// the provider doesn't report it. + projected-input-tokens: option, + } + + /// Payload for `after-model-call`. + record after-model-call-data { + /// 1-indexed attempt count for this turn. + attempt-count: s32, + /// Model response. Absent when an error occurred before completion. + stop-data: option, + /// Error when the call failed. Absent on success. + error: option, + } + + /// Payload for `before-tool-call`. + record before-tool-call-data { + /// Tool-use descriptor about to execute. + tool-use: tool-use-data, + } + + /// Payload for `after-tool-call`. + record after-tool-call-data { + /// Tool-use that ran. + tool-use: tool-use-data, + /// Tool result block. + tool-result: tool-result-block, + /// Error when the tool threw. Absent on success. + error: option, + } + + /// Payload for `before-tools` / `after-tools`. + record tools-batch-data { + /// Assistant message whose tool calls are about to run (or just ran). + message: message, + } + + /// Payload for `content-block`. + record content-block-data { + /// Fully-assembled content block. + content-block: content-block, + } + + /// Payload for `model-message`. + record model-message-data { + /// Assembled assistant message. + message: message, + /// Why the model stopped. + stop-reason: stop-reason, + } + + /// Payload for `tool-result-hook`. + record tool-result-data { + /// Completed tool-result block. + tool-result: tool-result-block, + } + + /// Payload for `tool-stream-update`. + record tool-stream-update-data { + /// Data from the streaming tool as a JSON value. + data: string, + } + + /// Payload for `model-stream-update`. + record model-stream-update-data { + /// Inner model stream event as a JSON value. + event: string, + } + + /// Payload for `agent-result`. + record agent-result-data { + /// Terminal stop event: stop reason, final usage, structured output. + stop: stop-event, + } + + /// Input redaction emitted when a guardrail blocks input. Original is in history. + record input-redaction { + /// Text to substitute for the blocked input. + replace-content: string, + } + + /// Output redaction emitted when a guardrail blocks output. + record output-redaction { + /// Original blocked content if the provider surfaced it. + redacted-content: option, + /// Text to substitute for the blocked output. + replace-content: string, + } + + /// Redaction event. Input and output fields are independent; at least one is set. + record redaction-event { + /// Present when input was redacted. + input-redaction: option, + /// Present when output was redacted. + output-redaction: option, + } + + /// Terminal event for a stream. + record stop-event { + /// Why generation stopped. + reason: stop-reason, + /// Final token usage. + usage: option, + /// Final performance metrics. + metrics: option, + /// Validated structured output as a JSON value. Present when a schema + /// was supplied. + structured-output: option, + } + + /// Why the agent loop surfaced an error mid-stream. + variant stream-error { + /// A model call failed. + model(model-error), + /// A tool call failed. + tool(tool-error), + /// Input exceeded the model's context window and no conversation + /// manager could recover. + context-window-exceeded, + /// Exceeded the model's max-tokens budget mid-response. + max-tokens-reached, + /// Structured output was requested but the model never called the + /// tool, even after being forced. + structured-output-unavailable, + /// Catch-all for internal failures. + internal(string), + } + + /// Events yielded during agent streaming. + /// Hot-path arms: `text-delta`, `tool-use`, `tool-result`. Other content + /// blocks flow through `content`. Lifecycle arms (`before-invocation` + /// through `agent-result`) mirror a hook system and can be filtered by tag. + variant stream-event { + /// Incremental text from the model. + text-delta(string), + /// Model requested a tool call. + tool-use(tool-use-block), + /// Tool call completed. + tool-result(tool-result-block), + /// Non-hot-path content block (image, reasoning, citations, etc). + content(content-block), + /// Cumulative usage and metrics snapshot. + metadata(metadata-event), + /// Terminal event for the stream. + stop(stop-event), + /// Guardrail redaction fired. + redaction(redaction-event), + /// Recoverable error surfaced mid-stream. + error(stream-error), + /// Human-in-the-loop pause; resume via `response-stream.respond`. + interrupt(interrupt), + /// Agent finished construction. + initialized, + /// About to process a user invocation. + before-invocation(before-invocation-data), + /// Finished processing a user invocation. + after-invocation(after-invocation-data), + /// A message was appended to the conversation. + message-added(message-added-data), + /// About to call the model. + before-model-call(before-model-call-data), + /// Model call returned. + after-model-call(after-model-call-data), + /// About to run a batch of tool calls from one assistant turn. + before-tools(tools-batch-data), + /// Tool batch finished. + after-tools(tools-batch-data), + /// About to call a single tool. + before-tool-call(before-tool-call-data), + /// Tool call returned. + after-tool-call(after-tool-call-data), + /// A content block was assembled during streaming. + content-block(content-block-data), + /// Model finished producing a full message. + model-message(model-message-data), + /// Tool finished execution (completion event, not streaming update). + tool-result-hook(tool-result-data), + /// Streaming update from a tool. + tool-update(tool-stream-update-data), + /// Streaming update from the model. + model-update(model-stream-update-data), + /// Final event for an invocation, carrying the terminal result. + agent-result(agent-result-data), + } +} diff --git a/wit/tools.wit b/wit/tools.wit new file mode 100644 index 0000000000..689450aa97 --- /dev/null +++ b/wit/tools.wit @@ -0,0 +1,102 @@ +package strands:agent@0.1.0; + +/// Tool definitions and tool execution. +interface tools { + use messages.{tool-result-content}; + + /// Declaration of a tool the model can call. + record tool-spec { + /// Unique tool identifier. + name: string, + /// Natural-language description shown to the model. + description: string, + /// JSON Schema describing the tool's parameters. + input-schema: string, + } + + /// Wrap a configured agent as a tool callable by the parent agent. The + /// child agent is instantiated at registration time. + record agent-as-tool-config { + /// Tool name exposed to the parent's model. Defaults to the child + /// agent's `name`. + name: option, + /// Tool description exposed to the parent's model. Defaults to the + /// child agent's `description` or a generic description. + description: option, + /// Whether the child retains its conversation history across calls. + /// `false` (default) resets state to construction-time on every call. + preserve-context: bool, + /// Child agent configuration as a JSON value matching `api.agent-config`. + /// Embedded here because a typed reference would be recursive. + agent-config: string, + } + + /// Arguments for a single tool call. + record call-tool-args { + /// Tool to invoke. + name: string, + /// Arguments as a JSON value. Shape is tool-specific. + input: string, + /// Identifier correlating this call with its result. + tool-use-id: string, + } + + /// Policy controlling whether and how the model calls tools on the next + /// generation step. + variant tool-choice { + /// Model decides whether to call a tool. + auto, + /// Model must call at least one tool. + any, + /// Model must call the tool with this name. + named(string), + } + + /// Incremental event emitted by a streaming tool while running. + variant tool-stream-event { + /// Partial progress data as a JSON value. + data(string), + /// Final completion (success). + complete(list), + /// Terminal error. + error(tool-error), + } + + /// Pull-based stream of tool events. Sync-WIT placeholder for + /// `stream`. + resource tool-event-stream { + /// Pull the next event. `none` after the terminal `complete` or `error`. + read: func() -> option; + } + + /// Why a tool call failed. + variant tool-error { + /// No tool registered under the given name. + unknown(string), + /// Tool input didn't match the declared input schema. + invalid-input(string), + /// Tool ran but returned an error result. + execution-failed(string), + /// Tool exceeded its time budget. + timed-out, + /// Tool was cancelled before completion. + cancelled, + /// Catch-all for internal failures. + internal(string), + } +} + +/// Tool execution. Your application implements this to expose tools to +/// the agent's model. +/// +/// Every call returns a stream. Non-streaming tools emit a single +/// `complete(...)` or `error(...)` event and close. Streaming tools emit +/// zero or more `data(...)` events before terminating with `complete` or +/// `error`. +interface tool-provider { + use tools.{call-tool-args, tool-event-stream}; + + /// Execute a tool. Iterate the returned stream to completion; the + /// final event is `complete(...)` on success or `error(...)` on failure. + call-tool: func(args: call-tool-args) -> tool-event-stream; +} diff --git a/wit/vended.wit b/wit/vended.wit new file mode 100644 index 0000000000..7b7088e852 --- /dev/null +++ b/wit/vended.wit @@ -0,0 +1,84 @@ +package strands:agent@0.1.0; + +/// Opt-in configuration for bundled tools and plugins that ship with the +/// agent. Enabling one of these is sufficient. No separate registration +/// or implementation is required on your side. +interface vended { + /// Built-in tools. + variant vended-tool { + /// Run shell commands in a persistent bash session. + bash(bash-tool-config), + /// Create, view, and edit files on disk. + file-editor(file-editor-tool-config), + /// Make HTTP requests. + http-request(http-request-tool-config), + /// Read and execute Jupyter notebook cells. + notebook(notebook-tool-config), + } + + /// Bash tool configuration. + record bash-tool-config { + /// Default timeout for `execute` calls, in seconds. + default-timeout-s: option, + } + + /// File editor tool configuration. + record file-editor-tool-config { + /// Directory outside of which the tool refuses to operate. Absent + /// permits any path the sandbox grants. + workspace-root: option, + } + + /// HTTP request tool configuration. + record http-request-tool-config { + /// Hosts the tool is allowed to reach. Empty permits any host the + /// sandbox permits. + allowed-hosts: list, + /// Upper bound on response size, in bytes. 0 means unbounded. + max-response-bytes: u64, + } + + /// Notebook tool configuration. + record notebook-tool-config { + /// Directory outside of which the tool refuses to operate. Absent + /// permits any path the sandbox grants. + workspace-root: option, + } + + /// Location of a skill definition on disk. + record skill-source { + /// Path to the skill directory. + path: string, + } + + /// Skills plugin configuration. + record skills-plugin-config { + /// Skill sources to load. + skills: list, + /// Fail if a skill cannot be loaded. + strict: bool, + /// Maximum resource files loaded per skill. + max-resource-files: option, + /// State-store key used to track active skills. + state-key: option, + } + + /// Context offloader plugin configuration. + record context-offloader-plugin-config { + /// Token threshold at which tool results are offloaded. + max-result-tokens: option, + /// Tokens to keep inline when offloading (as a preview). + preview-tokens: option, + /// Whether to register a retrieval tool the model can call to pull + /// offloaded content back in. + include-retrieval-tool: bool, + } + + /// Built-in plugins. + variant vended-plugin { + /// Load and activate Anthropic-style skills from disk. + skills(skills-plugin-config), + /// Offload large tool results to external storage. + context-offloader(context-offloader-plugin-config), + } +}