feat: add ruff linter/formatter and fix Python 3.10 test compatibility
- Add ruff configuration to pyproject.toml with sensible defaults - Add pre-commit hooks for ruff check and format - Update CI to run ruff check and format verification on tests/ - Fix Python 3.10 mock.patch compatibility in CLI tests - Use importlib to get actual modules, bypassing shadowed click groups - Add get_cli_module() helper and update patch_client_for_module() - Apply ruff formatting to all source and test files - Move .claude/skills to gitignore (local-only skills) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
892de6f76d
commit
9bd715af0d
81 changed files with 1548 additions and 1358 deletions
|
|
@ -1,57 +0,0 @@
|
|||
---
|
||||
name: matrix
|
||||
description: Run tests across multiple Python versions using Docker containers
|
||||
---
|
||||
|
||||
# Multi-Version Test Runner
|
||||
|
||||
Run the test suite across Python 3.10-3.14 using Docker containers in parallel.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
./dev/test-versions.sh
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
```bash
|
||||
# Run all versions (3.10, 3.11, 3.12, 3.13, 3.14)
|
||||
./dev/test-versions.sh
|
||||
|
||||
# Run specific versions only
|
||||
./dev/test-versions.sh 3.12 3.13
|
||||
|
||||
# Include readonly e2e tests (requires auth)
|
||||
./dev/test-versions.sh -r
|
||||
|
||||
# Pass pytest arguments (after --)
|
||||
./dev/test-versions.sh -- -k test_encoder -v
|
||||
|
||||
# Combine options
|
||||
./dev/test-versions.sh -r 3.12 -- -k test_encoder
|
||||
```
|
||||
|
||||
## When to Use
|
||||
|
||||
- Before committing changes that might affect Python version compatibility
|
||||
- When testing syntax or feature compatibility across versions
|
||||
- To validate all supported Python versions pass locally before CI
|
||||
|
||||
## Requirements
|
||||
|
||||
- Docker must be running
|
||||
- First run pulls Python images (~50MB each)
|
||||
- Subsequent runs use cached pip packages for speed (~30s per version)
|
||||
|
||||
## Output
|
||||
|
||||
Shows pass/fail status for each Python version with test summary.
|
||||
On failure, displays last 30 lines of output for debugging.
|
||||
|
||||
## Cache Management
|
||||
|
||||
To clear pip cache volumes:
|
||||
```bash
|
||||
docker volume ls | grep notebooklm-pip-cache | xargs docker volume rm
|
||||
```
|
||||
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
|
|
@ -29,7 +29,10 @@ jobs:
|
|||
pip install -e ".[all]"
|
||||
|
||||
- name: Run linting
|
||||
run: ruff check src/
|
||||
run: ruff check src/ tests/
|
||||
|
||||
- name: Check formatting
|
||||
run: ruff format --check src/ tests/
|
||||
|
||||
- name: Run type checking
|
||||
run: mypy src/notebooklm --ignore-missing-imports
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -17,3 +17,4 @@ captured_rpcs/
|
|||
.worktrees/
|
||||
.worktree/
|
||||
.sisyphus/
|
||||
.claude/
|
||||
|
|
|
|||
7
.pre-commit-config.yaml
Normal file
7
.pre-commit-config.yaml
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.6
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
|
@ -89,3 +89,31 @@ exclude = ["tests/"]
|
|||
module = "notebooklm.cli.*"
|
||||
warn_return_any = false
|
||||
strict_optional = true
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
line-length = 100
|
||||
src = ["src", "tests"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
"SIM", # flake8-simplify
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"B008", # function call in default argument (Click uses this)
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["notebooklm"]
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
|
|
|
|||
|
|
@ -16,52 +16,52 @@ Note:
|
|||
__version__ = "0.1.1"
|
||||
|
||||
# Public API: Authentication
|
||||
from .auth import AuthTokens, DEFAULT_STORAGE_PATH
|
||||
from .auth import DEFAULT_STORAGE_PATH, AuthTokens
|
||||
|
||||
# Public API: Client
|
||||
from .client import NotebookLMClient
|
||||
|
||||
# Public API: Types and dataclasses
|
||||
from .types import (
|
||||
Notebook,
|
||||
NotebookDescription,
|
||||
SuggestedTopic,
|
||||
Source,
|
||||
Artifact,
|
||||
GenerationStatus,
|
||||
ReportSuggestion,
|
||||
Note,
|
||||
ConversationTurn,
|
||||
AskResult,
|
||||
ChatMode,
|
||||
# Exceptions
|
||||
SourceError,
|
||||
SourceProcessingError,
|
||||
SourceTimeoutError,
|
||||
SourceNotFoundError,
|
||||
# Enums for configuration
|
||||
StudioContentType,
|
||||
AudioFormat,
|
||||
AudioLength,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
QuizQuantity,
|
||||
QuizDifficulty,
|
||||
InfographicOrientation,
|
||||
InfographicDetail,
|
||||
SlideDeckFormat,
|
||||
SlideDeckLength,
|
||||
ReportFormat,
|
||||
ChatGoal,
|
||||
ChatResponseLength,
|
||||
DriveMimeType,
|
||||
ExportType,
|
||||
SourceStatus,
|
||||
)
|
||||
|
||||
# Public API: RPC errors (needed for exception handling)
|
||||
from .rpc import RPCError
|
||||
|
||||
# Public API: Types and dataclasses
|
||||
from .types import (
|
||||
Artifact,
|
||||
AskResult,
|
||||
AudioFormat,
|
||||
AudioLength,
|
||||
ChatGoal,
|
||||
ChatMode,
|
||||
ChatResponseLength,
|
||||
ConversationTurn,
|
||||
DriveMimeType,
|
||||
ExportType,
|
||||
GenerationStatus,
|
||||
InfographicDetail,
|
||||
InfographicOrientation,
|
||||
Note,
|
||||
Notebook,
|
||||
NotebookDescription,
|
||||
QuizDifficulty,
|
||||
QuizQuantity,
|
||||
ReportFormat,
|
||||
ReportSuggestion,
|
||||
SlideDeckFormat,
|
||||
SlideDeckLength,
|
||||
Source,
|
||||
# Exceptions
|
||||
SourceError,
|
||||
SourceNotFoundError,
|
||||
SourceProcessingError,
|
||||
SourceStatus,
|
||||
SourceTimeoutError,
|
||||
# Enums for configuration
|
||||
StudioContentType,
|
||||
SuggestedTopic,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
# Client (main entry point)
|
||||
|
|
|
|||
|
|
@ -6,30 +6,31 @@ Quizzes, Flashcards, Infographics, Slide Decks, Data Tables, and Mind Maps.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
|
||||
from ._core import ClientCore
|
||||
from .auth import load_httpx_cookies
|
||||
from .rpc import (
|
||||
RPCMethod,
|
||||
RPCError,
|
||||
StudioContentType,
|
||||
ArtifactStatus,
|
||||
AudioFormat,
|
||||
AudioLength,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
QuizQuantity,
|
||||
QuizDifficulty,
|
||||
InfographicOrientation,
|
||||
ExportType,
|
||||
InfographicDetail,
|
||||
InfographicOrientation,
|
||||
QuizDifficulty,
|
||||
QuizQuantity,
|
||||
ReportFormat,
|
||||
RPCError,
|
||||
RPCMethod,
|
||||
SlideDeckFormat,
|
||||
SlideDeckLength,
|
||||
ReportFormat,
|
||||
ExportType,
|
||||
StudioContentType,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
)
|
||||
from .types import Artifact, GenerationStatus, ReportSuggestion
|
||||
|
||||
|
|
@ -73,9 +74,7 @@ class ArtifactsAPI:
|
|||
# List/Get Operations
|
||||
# =========================================================================
|
||||
|
||||
async def list(
|
||||
self, notebook_id: str, artifact_type: Optional[int] = None
|
||||
) -> List[Artifact]:
|
||||
async def list(self, notebook_id: str, artifact_type: int | None = None) -> list[Artifact]:
|
||||
"""List all artifacts in a notebook, including mind maps.
|
||||
|
||||
This returns all AI-generated content: Audio Overviews, Video Overviews,
|
||||
|
|
@ -93,7 +92,7 @@ class ArtifactsAPI:
|
|||
Returns:
|
||||
List of Artifact objects.
|
||||
"""
|
||||
artifacts: List[Artifact] = []
|
||||
artifacts: list[Artifact] = []
|
||||
|
||||
# Fetch studio artifacts (audio, video, reports, etc.)
|
||||
params = [[2], notebook_id, 'NOT artifact.status = "ARTIFACT_STATUS_SUGGESTED"']
|
||||
|
|
@ -121,7 +120,10 @@ class ArtifactsAPI:
|
|||
for mm_data in mind_maps:
|
||||
mind_map_artifact = Artifact.from_mind_map(mm_data)
|
||||
if mind_map_artifact is not None: # None means deleted (status=2)
|
||||
if artifact_type is None or mind_map_artifact.artifact_type == artifact_type:
|
||||
if (
|
||||
artifact_type is None
|
||||
or mind_map_artifact.artifact_type == artifact_type
|
||||
):
|
||||
artifacts.append(mind_map_artifact)
|
||||
except (RPCError, httpx.HTTPError) as e:
|
||||
# Network/API errors - log and continue with studio artifacts
|
||||
|
|
@ -131,7 +133,7 @@ class ArtifactsAPI:
|
|||
|
||||
return artifacts
|
||||
|
||||
async def get(self, notebook_id: str, artifact_id: str) -> Optional[Artifact]:
|
||||
async def get(self, notebook_id: str, artifact_id: str) -> Artifact | None:
|
||||
"""Get a specific artifact by ID.
|
||||
|
||||
Args:
|
||||
|
|
@ -147,37 +149,37 @@ class ArtifactsAPI:
|
|||
return artifact
|
||||
return None
|
||||
|
||||
async def list_audio(self, notebook_id: str) -> List[Artifact]:
|
||||
async def list_audio(self, notebook_id: str) -> builtins.list[Artifact]:
|
||||
"""List audio overview artifacts."""
|
||||
return await self.list(notebook_id, StudioContentType.AUDIO.value)
|
||||
|
||||
async def list_video(self, notebook_id: str) -> List[Artifact]:
|
||||
async def list_video(self, notebook_id: str) -> builtins.list[Artifact]:
|
||||
"""List video overview artifacts."""
|
||||
return await self.list(notebook_id, StudioContentType.VIDEO.value)
|
||||
|
||||
async def list_reports(self, notebook_id: str) -> List[Artifact]:
|
||||
async def list_reports(self, notebook_id: str) -> builtins.list[Artifact]:
|
||||
"""List report artifacts (Briefing Doc, Study Guide, Blog Post)."""
|
||||
return await self.list(notebook_id, StudioContentType.REPORT.value)
|
||||
|
||||
async def list_quizzes(self, notebook_id: str) -> List[Artifact]:
|
||||
async def list_quizzes(self, notebook_id: str) -> builtins.list[Artifact]:
|
||||
"""List quiz artifacts (type 4 with variant 2)."""
|
||||
all_type4 = await self.list(notebook_id, StudioContentType.QUIZ_FLASHCARD.value)
|
||||
return [a for a in all_type4 if a.is_quiz]
|
||||
|
||||
async def list_flashcards(self, notebook_id: str) -> List[Artifact]:
|
||||
async def list_flashcards(self, notebook_id: str) -> builtins.list[Artifact]:
|
||||
"""List flashcard artifacts (type 4 with variant 1)."""
|
||||
all_type4 = await self.list(notebook_id, StudioContentType.QUIZ_FLASHCARD.value)
|
||||
return [a for a in all_type4 if a.is_flashcards]
|
||||
|
||||
async def list_infographics(self, notebook_id: str) -> List[Artifact]:
|
||||
async def list_infographics(self, notebook_id: str) -> builtins.list[Artifact]:
|
||||
"""List infographic artifacts."""
|
||||
return await self.list(notebook_id, StudioContentType.INFOGRAPHIC.value)
|
||||
|
||||
async def list_slide_decks(self, notebook_id: str) -> List[Artifact]:
|
||||
async def list_slide_decks(self, notebook_id: str) -> builtins.list[Artifact]:
|
||||
"""List slide deck artifacts."""
|
||||
return await self.list(notebook_id, StudioContentType.SLIDE_DECK.value)
|
||||
|
||||
async def list_data_tables(self, notebook_id: str) -> List[Artifact]:
|
||||
async def list_data_tables(self, notebook_id: str) -> builtins.list[Artifact]:
|
||||
"""List data table artifacts."""
|
||||
return await self.list(notebook_id, StudioContentType.DATA_TABLE.value)
|
||||
|
||||
|
|
@ -188,11 +190,11 @@ class ArtifactsAPI:
|
|||
async def generate_audio(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
language: str = "en",
|
||||
instructions: Optional[str] = None,
|
||||
audio_format: Optional[AudioFormat] = None,
|
||||
audio_length: Optional[AudioLength] = None,
|
||||
instructions: str | None = None,
|
||||
audio_format: AudioFormat | None = None,
|
||||
audio_length: AudioLength | None = None,
|
||||
) -> GenerationStatus:
|
||||
"""Generate an Audio Overview (podcast).
|
||||
|
||||
|
|
@ -245,11 +247,11 @@ class ArtifactsAPI:
|
|||
async def generate_video(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
language: str = "en",
|
||||
instructions: Optional[str] = None,
|
||||
video_format: Optional[VideoFormat] = None,
|
||||
video_style: Optional[VideoStyle] = None,
|
||||
instructions: str | None = None,
|
||||
video_format: VideoFormat | None = None,
|
||||
video_style: VideoStyle | None = None,
|
||||
) -> GenerationStatus:
|
||||
"""Generate a Video Overview.
|
||||
|
||||
|
|
@ -305,9 +307,9 @@ class ArtifactsAPI:
|
|||
self,
|
||||
notebook_id: str,
|
||||
report_format: ReportFormat = ReportFormat.BRIEFING_DOC,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
language: str = "en",
|
||||
custom_prompt: Optional[str] = None,
|
||||
custom_prompt: str | None = None,
|
||||
) -> GenerationStatus:
|
||||
"""Generate a report artifact.
|
||||
|
||||
|
|
@ -395,7 +397,7 @@ class ArtifactsAPI:
|
|||
async def generate_study_guide(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
language: str = "en",
|
||||
) -> GenerationStatus:
|
||||
"""Generate a study guide report.
|
||||
|
|
@ -420,10 +422,10 @@ class ArtifactsAPI:
|
|||
async def generate_quiz(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
instructions: Optional[str] = None,
|
||||
quantity: Optional[QuizQuantity] = None,
|
||||
difficulty: Optional[QuizDifficulty] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
instructions: str | None = None,
|
||||
quantity: QuizQuantity | None = None,
|
||||
difficulty: QuizDifficulty | None = None,
|
||||
) -> GenerationStatus:
|
||||
"""Generate a quiz.
|
||||
|
||||
|
|
@ -477,10 +479,10 @@ class ArtifactsAPI:
|
|||
async def generate_flashcards(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
instructions: Optional[str] = None,
|
||||
quantity: Optional[QuizQuantity] = None,
|
||||
difficulty: Optional[QuizDifficulty] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
instructions: str | None = None,
|
||||
quantity: QuizQuantity | None = None,
|
||||
difficulty: QuizDifficulty | None = None,
|
||||
) -> GenerationStatus:
|
||||
"""Generate flashcards.
|
||||
|
||||
|
|
@ -533,11 +535,11 @@ class ArtifactsAPI:
|
|||
async def generate_infographic(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
language: str = "en",
|
||||
instructions: Optional[str] = None,
|
||||
orientation: Optional[InfographicOrientation] = None,
|
||||
detail_level: Optional[InfographicDetail] = None,
|
||||
instructions: str | None = None,
|
||||
orientation: InfographicOrientation | None = None,
|
||||
detail_level: InfographicDetail | None = None,
|
||||
) -> GenerationStatus:
|
||||
"""Generate an infographic.
|
||||
|
||||
|
|
@ -585,11 +587,11 @@ class ArtifactsAPI:
|
|||
async def generate_slide_deck(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
language: str = "en",
|
||||
instructions: Optional[str] = None,
|
||||
slide_format: Optional[SlideDeckFormat] = None,
|
||||
slide_length: Optional[SlideDeckLength] = None,
|
||||
instructions: str | None = None,
|
||||
slide_format: SlideDeckFormat | None = None,
|
||||
slide_length: SlideDeckLength | None = None,
|
||||
) -> GenerationStatus:
|
||||
"""Generate a slide deck.
|
||||
|
||||
|
|
@ -639,9 +641,9 @@ class ArtifactsAPI:
|
|||
async def generate_data_table(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
language: str = "en",
|
||||
instructions: Optional[str] = None,
|
||||
instructions: str | None = None,
|
||||
) -> GenerationStatus:
|
||||
"""Generate a data table.
|
||||
|
||||
|
|
@ -689,7 +691,7 @@ class ArtifactsAPI:
|
|||
async def generate_mind_map(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate an interactive mind map.
|
||||
|
||||
|
|
@ -766,7 +768,7 @@ class ArtifactsAPI:
|
|||
# =========================================================================
|
||||
|
||||
async def download_audio(
|
||||
self, notebook_id: str, output_path: str, artifact_id: Optional[str] = None
|
||||
self, notebook_id: str, output_path: str, artifact_id: str | None = None
|
||||
) -> str:
|
||||
"""Download an Audio Overview to a file.
|
||||
|
||||
|
|
@ -782,8 +784,10 @@ class ArtifactsAPI:
|
|||
|
||||
# Filter for completed audio artifacts
|
||||
audio_candidates = [
|
||||
a for a in artifacts_data
|
||||
if isinstance(a, list) and len(a) > 4
|
||||
a
|
||||
for a in artifacts_data
|
||||
if isinstance(a, list)
|
||||
and len(a) > 4
|
||||
and a[2] == StudioContentType.AUDIO
|
||||
and a[4] == ArtifactStatus.COMPLETED
|
||||
]
|
||||
|
|
@ -826,7 +830,7 @@ class ArtifactsAPI:
|
|||
raise ValueError(f"Failed to parse audio artifact structure: {e}")
|
||||
|
||||
async def download_video(
|
||||
self, notebook_id: str, output_path: str, artifact_id: Optional[str] = None
|
||||
self, notebook_id: str, output_path: str, artifact_id: str | None = None
|
||||
) -> str:
|
||||
"""Download a Video Overview to a file.
|
||||
|
||||
|
|
@ -842,8 +846,10 @@ class ArtifactsAPI:
|
|||
|
||||
# Filter for completed video artifacts
|
||||
video_candidates = [
|
||||
a for a in artifacts_data
|
||||
if isinstance(a, list) and len(a) > 4
|
||||
a
|
||||
for a in artifacts_data
|
||||
if isinstance(a, list)
|
||||
and len(a) > 4
|
||||
and a[2] == StudioContentType.VIDEO
|
||||
and a[4] == ArtifactStatus.COMPLETED
|
||||
]
|
||||
|
|
@ -902,7 +908,7 @@ class ArtifactsAPI:
|
|||
raise ValueError(f"Failed to parse video artifact structure: {e}")
|
||||
|
||||
async def download_infographic(
|
||||
self, notebook_id: str, output_path: str, artifact_id: Optional[str] = None
|
||||
self, notebook_id: str, output_path: str, artifact_id: str | None = None
|
||||
) -> str:
|
||||
"""Download an Infographic to a file.
|
||||
|
||||
|
|
@ -918,8 +924,10 @@ class ArtifactsAPI:
|
|||
|
||||
# Filter for completed infographic artifacts
|
||||
info_candidates = [
|
||||
a for a in artifacts_data
|
||||
if isinstance(a, list) and len(a) > 4
|
||||
a
|
||||
for a in artifacts_data
|
||||
if isinstance(a, list)
|
||||
and len(a) > 4
|
||||
and a[2] == StudioContentType.INFOGRAPHIC
|
||||
and a[4] == ArtifactStatus.COMPLETED
|
||||
]
|
||||
|
|
@ -962,7 +970,7 @@ class ArtifactsAPI:
|
|||
raise ValueError(f"Failed to parse infographic structure: {e}")
|
||||
|
||||
async def download_slide_deck(
|
||||
self, notebook_id: str, output_path: str, artifact_id: Optional[str] = None
|
||||
self, notebook_id: str, output_path: str, artifact_id: str | None = None
|
||||
) -> str:
|
||||
"""Download a slide deck as a PDF file.
|
||||
|
||||
|
|
@ -978,8 +986,10 @@ class ArtifactsAPI:
|
|||
|
||||
# Filter for completed slide deck artifacts
|
||||
slide_candidates = [
|
||||
a for a in artifacts_data
|
||||
if isinstance(a, list) and len(a) > 4
|
||||
a
|
||||
for a in artifacts_data
|
||||
if isinstance(a, list)
|
||||
and len(a) > 4
|
||||
and a[2] == StudioContentType.SLIDE_DECK
|
||||
and a[4] == ArtifactStatus.COMPLETED
|
||||
]
|
||||
|
|
@ -1036,9 +1046,7 @@ class ArtifactsAPI:
|
|||
)
|
||||
return True
|
||||
|
||||
async def rename(
|
||||
self, notebook_id: str, artifact_id: str, new_title: str
|
||||
) -> None:
|
||||
async def rename(self, notebook_id: str, artifact_id: str, new_title: str) -> None:
|
||||
"""Rename an artifact.
|
||||
|
||||
Args:
|
||||
|
|
@ -1095,7 +1103,7 @@ class ArtifactsAPI:
|
|||
initial_interval: float = 2.0,
|
||||
max_interval: float = 10.0,
|
||||
timeout: float = 300.0,
|
||||
poll_interval: Optional[float] = None, # Deprecated, use initial_interval
|
||||
poll_interval: float | None = None, # Deprecated, use initial_interval
|
||||
) -> GenerationStatus:
|
||||
"""Wait for a generation task to complete.
|
||||
|
||||
|
|
@ -1118,6 +1126,7 @@ class ArtifactsAPI:
|
|||
# Backward compatibility: poll_interval overrides initial_interval
|
||||
if poll_interval is not None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"poll_interval is deprecated, use initial_interval instead",
|
||||
DeprecationWarning,
|
||||
|
|
@ -1204,8 +1213,8 @@ class ArtifactsAPI:
|
|||
async def export(
|
||||
self,
|
||||
notebook_id: str,
|
||||
artifact_id: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
artifact_id: str | None = None,
|
||||
content: str | None = None,
|
||||
title: str = "Export",
|
||||
export_type: ExportType = ExportType.DOCS,
|
||||
) -> Any:
|
||||
|
|
@ -1238,8 +1247,8 @@ class ArtifactsAPI:
|
|||
async def suggest_reports(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: Optional[List[str]] = None,
|
||||
) -> List[ReportSuggestion]:
|
||||
source_ids: builtins.list[str] | None = None,
|
||||
) -> builtins.list[ReportSuggestion]:
|
||||
"""Get AI-suggested report formats for a notebook.
|
||||
|
||||
Args:
|
||||
|
|
@ -1276,12 +1285,14 @@ class ArtifactsAPI:
|
|||
if result and isinstance(result, list):
|
||||
for item in result:
|
||||
if isinstance(item, list) and len(item) >= 5:
|
||||
suggestions.append(ReportSuggestion(
|
||||
title=item[0] if isinstance(item[0], str) else "",
|
||||
description=item[1] if isinstance(item[1], str) else "",
|
||||
prompt=item[4] if len(item) > 4 and isinstance(item[4], str) else "",
|
||||
audience_level=item[5] if len(item) > 5 else 2,
|
||||
))
|
||||
suggestions.append(
|
||||
ReportSuggestion(
|
||||
title=item[0] if isinstance(item[0], str) else "",
|
||||
description=item[1] if isinstance(item[1], str) else "",
|
||||
prompt=item[4] if len(item) > 4 and isinstance(item[4], str) else "",
|
||||
audience_level=item[5] if len(item) > 5 else 2,
|
||||
)
|
||||
)
|
||||
|
||||
return suggestions
|
||||
|
||||
|
|
@ -1289,7 +1300,9 @@ class ArtifactsAPI:
|
|||
# Private Helpers
|
||||
# =========================================================================
|
||||
|
||||
async def _call_generate(self, notebook_id: str, params: List[Any]) -> GenerationStatus:
|
||||
async def _call_generate(
|
||||
self, notebook_id: str, params: builtins.list[Any]
|
||||
) -> GenerationStatus:
|
||||
"""Make a generation RPC call with error handling.
|
||||
|
||||
Wraps the RPC call to handle UserDisplayableError (rate limiting/quota)
|
||||
|
|
@ -1320,7 +1333,7 @@ class ArtifactsAPI:
|
|||
)
|
||||
raise
|
||||
|
||||
async def _list_raw(self, notebook_id: str) -> List[Any]:
|
||||
async def _list_raw(self, notebook_id: str) -> builtins.list[Any]:
|
||||
"""Get raw artifact list data."""
|
||||
params = [[2], notebook_id, 'NOT artifact.status = "ARTIFACT_STATUS_SUGGESTED"']
|
||||
result = await self._core.rpc_call(
|
||||
|
|
@ -1333,7 +1346,7 @@ class ArtifactsAPI:
|
|||
return result[0] if isinstance(result[0], list) else result
|
||||
return []
|
||||
|
||||
async def _get_source_ids(self, notebook_id: str) -> List[str]:
|
||||
async def _get_source_ids(self, notebook_id: str) -> builtins.list[str]:
|
||||
"""Extract source IDs from notebook data."""
|
||||
params = [notebook_id, None, [2], None, 0]
|
||||
notebook_data = await self._core.rpc_call(
|
||||
|
|
@ -1364,8 +1377,8 @@ class ArtifactsAPI:
|
|||
return source_ids
|
||||
|
||||
async def _download_urls_batch(
|
||||
self, urls_and_paths: List[Tuple[str, str]]
|
||||
) -> List[str]:
|
||||
self, urls_and_paths: builtins.list[tuple[str, str]]
|
||||
) -> builtins.list[str]:
|
||||
"""Download multiple files using httpx with proper cookie handling.
|
||||
|
||||
Args:
|
||||
|
|
@ -1376,7 +1389,7 @@ class ArtifactsAPI:
|
|||
"""
|
||||
from pathlib import Path
|
||||
|
||||
downloaded: List[str] = []
|
||||
downloaded: list[str] = []
|
||||
|
||||
# Load cookies with domain info for cross-domain redirect handling
|
||||
cookies = load_httpx_cookies()
|
||||
|
|
@ -1450,11 +1463,27 @@ class ArtifactsAPI:
|
|||
"""Parse generation API result into GenerationStatus."""
|
||||
if result and isinstance(result, list) and len(result) > 0:
|
||||
artifact_data = result[0]
|
||||
artifact_id = artifact_data[0] if isinstance(artifact_data, list) and len(artifact_data) > 0 else None
|
||||
status_code = artifact_data[4] if isinstance(artifact_data, list) and len(artifact_data) > 4 else None
|
||||
artifact_id = (
|
||||
artifact_data[0]
|
||||
if isinstance(artifact_data, list) and len(artifact_data) > 0
|
||||
else None
|
||||
)
|
||||
status_code = (
|
||||
artifact_data[4]
|
||||
if isinstance(artifact_data, list) and len(artifact_data) > 4
|
||||
else None
|
||||
)
|
||||
|
||||
if artifact_id:
|
||||
status = "in_progress" if status_code == 1 else "completed" if status_code == 3 else "pending"
|
||||
status = (
|
||||
"in_progress"
|
||||
if status_code == 1
|
||||
else "completed"
|
||||
if status_code == 3
|
||||
else "pending"
|
||||
)
|
||||
return GenerationStatus(task_id=artifact_id, status=status)
|
||||
|
||||
return GenerationStatus(task_id="", status="failed", error="Generation failed - no artifact_id returned")
|
||||
return GenerationStatus(
|
||||
task_id="", status="failed", error="Generation failed - no artifact_id returned"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,11 +8,11 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlencode, quote
|
||||
from typing import Any
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
from ._core import ClientCore
|
||||
from .rpc import RPCMethod, QUERY_URL
|
||||
from .rpc import QUERY_URL, RPCMethod
|
||||
from .types import AskResult, ConversationTurn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -50,8 +50,8 @@ class ChatAPI:
|
|||
self,
|
||||
notebook_id: str,
|
||||
question: str,
|
||||
source_ids: Optional[list[str]] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
source_ids: list[str] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> AskResult:
|
||||
"""Ask the notebook a question.
|
||||
|
||||
|
|
@ -114,9 +114,7 @@ class ChatAPI:
|
|||
|
||||
self._core._reqid_counter += 100000
|
||||
url_params = {
|
||||
"bl": os.environ.get(
|
||||
"NOTEBOOKLM_BL", "boq_labs-tailwind-frontend_20251221.14_p0"
|
||||
),
|
||||
"bl": os.environ.get("NOTEBOOKLM_BL", "boq_labs-tailwind-frontend_20251221.14_p0"),
|
||||
"hl": "en",
|
||||
"_reqid": str(self._core._reqid_counter),
|
||||
"rt": "c",
|
||||
|
|
@ -136,9 +134,7 @@ class ChatAPI:
|
|||
if answer_text:
|
||||
turns = self._core.get_cached_conversation(conversation_id)
|
||||
turn_number = len(turns) + 1
|
||||
self._core.cache_conversation_turn(
|
||||
conversation_id, question, answer_text, turn_number
|
||||
)
|
||||
self._core.cache_conversation_turn(conversation_id, question, answer_text, turn_number)
|
||||
else:
|
||||
turns = self._core.get_cached_conversation(conversation_id)
|
||||
turn_number = len(turns)
|
||||
|
|
@ -187,7 +183,7 @@ class ChatAPI:
|
|||
for turn in cached
|
||||
]
|
||||
|
||||
def clear_cache(self, conversation_id: Optional[str] = None) -> bool:
|
||||
def clear_cache(self, conversation_id: str | None = None) -> bool:
|
||||
"""Clear conversation cache.
|
||||
|
||||
Args:
|
||||
|
|
@ -201,9 +197,9 @@ class ChatAPI:
|
|||
async def configure(
|
||||
self,
|
||||
notebook_id: str,
|
||||
goal: Optional[Any] = None,
|
||||
response_length: Optional[Any] = None,
|
||||
custom_prompt: Optional[str] = None,
|
||||
goal: Any | None = None,
|
||||
response_length: Any | None = None,
|
||||
custom_prompt: str | None = None,
|
||||
) -> None:
|
||||
"""Configure chat persona and response settings for a notebook.
|
||||
|
||||
|
|
@ -298,7 +294,7 @@ class ChatAPI:
|
|||
|
||||
return source_ids
|
||||
|
||||
def _build_conversation_history(self, conversation_id: str) -> Optional[list]:
|
||||
def _build_conversation_history(self, conversation_id: str) -> list | None:
|
||||
"""Build conversation history for follow-up requests."""
|
||||
turns = self._core.get_cached_conversation(conversation_id)
|
||||
if not turns:
|
||||
|
|
@ -347,7 +343,7 @@ class ChatAPI:
|
|||
)
|
||||
return longest_answer
|
||||
|
||||
def _extract_answer_from_chunk(self, json_str: str) -> tuple[Optional[str], bool]:
|
||||
def _extract_answer_from_chunk(self, json_str: str) -> tuple[str | None, bool]:
|
||||
"""Extract answer text from a response chunk."""
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
|
|
|
|||
|
|
@ -3,19 +3,19 @@
|
|||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
from .auth import AuthTokens
|
||||
from .rpc import (
|
||||
RPCMethod,
|
||||
RPCError,
|
||||
BATCHEXECUTE_URL,
|
||||
encode_rpc_request,
|
||||
RPCError,
|
||||
RPCMethod,
|
||||
build_request_body,
|
||||
decode_response,
|
||||
encode_rpc_request,
|
||||
)
|
||||
|
||||
# Enable RPC debug output via environment variable
|
||||
|
|
@ -57,7 +57,7 @@ class ClientCore:
|
|||
"""
|
||||
self.auth = auth
|
||||
self._timeout = timeout
|
||||
self._http_client: Optional[httpx.AsyncClient] = None
|
||||
self._http_client: httpx.AsyncClient | None = None
|
||||
# Request ID counter for chat API (must be unique per request)
|
||||
self._reqid_counter: int = 100000
|
||||
# OrderedDict for FIFO eviction when cache exceeds MAX_CONVERSATION_CACHE_SIZE
|
||||
|
|
@ -216,11 +216,13 @@ class ClientCore:
|
|||
self._conversation_cache.popitem(last=False)
|
||||
self._conversation_cache[conversation_id] = []
|
||||
|
||||
self._conversation_cache[conversation_id].append({
|
||||
"query": query,
|
||||
"answer": answer,
|
||||
"turn_number": turn_number,
|
||||
})
|
||||
self._conversation_cache[conversation_id].append(
|
||||
{
|
||||
"query": query,
|
||||
"answer": answer,
|
||||
"turn_number": turn_number,
|
||||
}
|
||||
)
|
||||
|
||||
def get_cached_conversation(self, conversation_id: str) -> list[dict[str, Any]]:
|
||||
"""Get cached conversation turns.
|
||||
|
|
@ -233,7 +235,7 @@ class ClientCore:
|
|||
"""
|
||||
return self._conversation_cache.get(conversation_id, [])
|
||||
|
||||
def clear_conversation_cache(self, conversation_id: Optional[str] = None) -> bool:
|
||||
def clear_conversation_cache(self, conversation_id: str | None = None) -> bool:
|
||||
"""Clear conversation cache.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Notebook operations API."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from ._core import ClientCore
|
||||
from .rpc import RPCMethod
|
||||
|
|
@ -169,10 +169,12 @@ class NotebooksAPI:
|
|||
topics_list = result[1][0] if isinstance(result[1][0], list) else []
|
||||
for topic in topics_list:
|
||||
if isinstance(topic, list) and len(topic) >= 2:
|
||||
suggested_topics.append(SuggestedTopic(
|
||||
question=topic[0] if isinstance(topic[0], str) else "",
|
||||
prompt=topic[1] if isinstance(topic[1], str) else "",
|
||||
))
|
||||
suggested_topics.append(
|
||||
SuggestedTopic(
|
||||
question=topic[0] if isinstance(topic[0], str) else "",
|
||||
prompt=topic[1] if isinstance(topic[1], str) else "",
|
||||
)
|
||||
)
|
||||
|
||||
return NotebookDescription(summary=summary, suggested_topics=suggested_topics)
|
||||
|
||||
|
|
@ -209,7 +211,7 @@ class NotebooksAPI:
|
|||
)
|
||||
|
||||
async def share(
|
||||
self, notebook_id: str, public: bool = True, artifact_id: Optional[str] = None
|
||||
self, notebook_id: str, public: bool = True, artifact_id: str | None = None
|
||||
) -> dict:
|
||||
"""Toggle notebook sharing.
|
||||
|
||||
|
|
@ -252,9 +254,7 @@ class NotebooksAPI:
|
|||
"artifact_id": artifact_id,
|
||||
}
|
||||
|
||||
def get_share_url(
|
||||
self, notebook_id: str, artifact_id: Optional[str] = None
|
||||
) -> str:
|
||||
def get_share_url(self, notebook_id: str, artifact_id: str | None = None) -> str:
|
||||
"""Get share URL for a notebook or artifact.
|
||||
|
||||
This does NOT toggle sharing - it just returns the URL format.
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ user-created notes in notebooks. Notes are distinct from artifacts -
|
|||
they are user-created content, not AI-generated.
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
import builtins
|
||||
from typing import Any
|
||||
|
||||
from ._core import ClientCore
|
||||
from .rpc import RPCMethod
|
||||
|
|
@ -37,7 +38,7 @@ class NotesAPI:
|
|||
"""
|
||||
self._core = core
|
||||
|
||||
async def list(self, notebook_id: str) -> List[Note]:
|
||||
async def list(self, notebook_id: str) -> list[Note]:
|
||||
"""List all text notes in the notebook.
|
||||
|
||||
This excludes:
|
||||
|
|
@ -59,15 +60,13 @@ class NotesAPI:
|
|||
continue
|
||||
|
||||
content = self._extract_content(item)
|
||||
is_mind_map = content and (
|
||||
'"children":' in content or '"nodes":' in content
|
||||
)
|
||||
is_mind_map = content and ('"children":' in content or '"nodes":' in content)
|
||||
if not is_mind_map:
|
||||
notes.append(self._parse_note(item, notebook_id))
|
||||
|
||||
return notes
|
||||
|
||||
async def get(self, notebook_id: str, note_id: str) -> Optional[Note]:
|
||||
async def get(self, notebook_id: str, note_id: str) -> Note | None:
|
||||
"""Get a specific note by ID.
|
||||
|
||||
Args:
|
||||
|
|
@ -173,7 +172,7 @@ class NotesAPI:
|
|||
)
|
||||
return True
|
||||
|
||||
async def list_mind_maps(self, notebook_id: str) -> List[Any]:
|
||||
async def list_mind_maps(self, notebook_id: str) -> builtins.list[Any]:
|
||||
"""List all mind maps in the notebook.
|
||||
|
||||
Mind maps are stored in the same internal structure as notes but
|
||||
|
|
@ -227,7 +226,7 @@ class NotesAPI:
|
|||
# Private Helpers
|
||||
# =========================================================================
|
||||
|
||||
async def _get_all_notes_and_mind_maps(self, notebook_id: str) -> List[Any]:
|
||||
async def _get_all_notes_and_mind_maps(self, notebook_id: str) -> builtins.list[Any]:
|
||||
"""Fetch all notes and mind maps from the API."""
|
||||
params = [notebook_id]
|
||||
result = await self._core.rpc_call(
|
||||
|
|
@ -236,25 +235,16 @@ class NotesAPI:
|
|||
source_path=f"/notebook/{notebook_id}",
|
||||
allow_null=True,
|
||||
)
|
||||
if (
|
||||
result
|
||||
and isinstance(result, list)
|
||||
and len(result) > 0
|
||||
and isinstance(result[0], list)
|
||||
):
|
||||
if result and isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
|
||||
notes_list = result[0]
|
||||
valid_notes = []
|
||||
for item in notes_list:
|
||||
if (
|
||||
isinstance(item, list)
|
||||
and len(item) > 0
|
||||
and isinstance(item[0], str)
|
||||
):
|
||||
if isinstance(item, list) and len(item) > 0 and isinstance(item[0], str):
|
||||
valid_notes.append(item)
|
||||
return valid_notes
|
||||
return []
|
||||
|
||||
def _is_deleted(self, item: List[Any]) -> bool:
|
||||
def _is_deleted(self, item: builtins.list[Any]) -> bool:
|
||||
"""Check if a note/mind map item is deleted (status=2).
|
||||
|
||||
Deleted items have structure: ['id', None, 2]
|
||||
|
|
@ -270,7 +260,7 @@ class NotesAPI:
|
|||
return False
|
||||
return item[1] is None and item[2] == 2
|
||||
|
||||
def _extract_content(self, item: List[Any]) -> Optional[str]:
|
||||
def _extract_content(self, item: builtins.list[Any]) -> str | None:
|
||||
"""Extract content string from note/mind map item."""
|
||||
if len(item) <= 1:
|
||||
return None
|
||||
|
|
@ -281,7 +271,7 @@ class NotesAPI:
|
|||
return item[1][1]
|
||||
return None
|
||||
|
||||
def _parse_note(self, item: List[Any], notebook_id: str) -> Note:
|
||||
def _parse_note(self, item: builtins.list[Any], notebook_id: str) -> Note:
|
||||
"""Parse a raw note item into a Note object."""
|
||||
note_id = item[0] if len(item) > 0 else ""
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Provides operations for starting research sessions, polling for results,
|
|||
and importing discovered sources into notebooks.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from ._core import ClientCore
|
||||
from .rpc import RPCMethod
|
||||
|
|
@ -44,7 +44,7 @@ class ResearchAPI:
|
|||
query: str,
|
||||
source: str = "web",
|
||||
mode: str = "fast",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""Start a research session.
|
||||
|
||||
Args:
|
||||
|
|
@ -117,11 +117,7 @@ class ResearchAPI:
|
|||
return {"status": "no_research"}
|
||||
|
||||
# Unwrap if needed
|
||||
if (
|
||||
isinstance(result[0], list)
|
||||
and len(result[0]) > 0
|
||||
and isinstance(result[0][0], list)
|
||||
):
|
||||
if isinstance(result[0], list) and len(result[0]) > 0 and isinstance(result[0][0], list):
|
||||
result = result[0]
|
||||
|
||||
# Find most recent task
|
||||
|
|
@ -145,13 +141,9 @@ class ResearchAPI:
|
|||
|
||||
if isinstance(sources_and_summary, list) and len(sources_and_summary) >= 1:
|
||||
sources_data = (
|
||||
sources_and_summary[0]
|
||||
if isinstance(sources_and_summary[0], list)
|
||||
else []
|
||||
sources_and_summary[0] if isinstance(sources_and_summary[0], list) else []
|
||||
)
|
||||
if len(sources_and_summary) >= 2 and isinstance(
|
||||
sources_and_summary[1], str
|
||||
):
|
||||
if len(sources_and_summary) >= 2 and isinstance(sources_and_summary[1], str):
|
||||
summary = sources_and_summary[1]
|
||||
|
||||
parsed_sources = []
|
||||
|
|
@ -246,9 +238,7 @@ class ResearchAPI:
|
|||
for src_data in result:
|
||||
if isinstance(src_data, list) and len(src_data) >= 2:
|
||||
src_id = (
|
||||
src_data[0][0]
|
||||
if src_data[0] and isinstance(src_data[0], list)
|
||||
else None
|
||||
src_data[0][0] if src_data[0] and isinstance(src_data[0], list) else None
|
||||
)
|
||||
if src_id:
|
||||
imported.append({"id": src_id, "title": src_data[1]})
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
"""Source operations API."""
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from time import monotonic
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from ._core import ClientCore
|
||||
from .rpc import RPCMethod, UPLOAD_URL
|
||||
from .rpc import UPLOAD_URL, RPCMethod
|
||||
from .rpc.types import SourceStatus
|
||||
from .types import (
|
||||
Source,
|
||||
|
|
@ -91,22 +92,27 @@ class SourcesAPI:
|
|||
if isinstance(url_list, list) and len(url_list) > 0:
|
||||
url = url_list[0]
|
||||
# Detect YouTube vs other URLs
|
||||
if 'youtube.com' in url or 'youtu.be' in url:
|
||||
if "youtube.com" in url or "youtu.be" in url:
|
||||
source_type = "youtube"
|
||||
else:
|
||||
source_type = "url"
|
||||
|
||||
# Extract file info if no URL
|
||||
if not url and title:
|
||||
if title.endswith('.pdf'):
|
||||
if title.endswith(".pdf"):
|
||||
source_type = "pdf"
|
||||
elif title.endswith(('.txt', '.md', '.doc', '.docx')):
|
||||
elif title.endswith((".txt", ".md", ".doc", ".docx")):
|
||||
source_type = "text_file"
|
||||
elif title.endswith(('.xls', '.xlsx', '.csv')):
|
||||
elif title.endswith((".xls", ".xlsx", ".csv")):
|
||||
source_type = "spreadsheet"
|
||||
|
||||
# Check for file upload indicator
|
||||
if source_type == "text" and len(src) > 2 and isinstance(src[2], list) and len(src[2]) > 1:
|
||||
if (
|
||||
source_type == "text"
|
||||
and len(src) > 2
|
||||
and isinstance(src[2], list)
|
||||
and len(src[2]) > 1
|
||||
):
|
||||
if isinstance(src[2][1], int) and src[2][1] > 0:
|
||||
source_type = "upload"
|
||||
|
||||
|
|
@ -132,18 +138,20 @@ class SourcesAPI:
|
|||
):
|
||||
status = status_code
|
||||
|
||||
sources.append(Source(
|
||||
id=str(src_id),
|
||||
title=title,
|
||||
url=url,
|
||||
source_type=source_type,
|
||||
created_at=created_at,
|
||||
status=status,
|
||||
))
|
||||
sources.append(
|
||||
Source(
|
||||
id=str(src_id),
|
||||
title=title,
|
||||
url=url,
|
||||
source_type=source_type,
|
||||
created_at=created_at,
|
||||
status=status,
|
||||
)
|
||||
)
|
||||
|
||||
return sources
|
||||
|
||||
async def get(self, notebook_id: str, source_id: str) -> Optional[Source]:
|
||||
async def get(self, notebook_id: str, source_id: str) -> Source | None:
|
||||
"""Get details of a specific source.
|
||||
|
||||
Args:
|
||||
|
|
@ -200,7 +208,7 @@ class SourcesAPI:
|
|||
"""
|
||||
start = monotonic()
|
||||
interval = initial_interval
|
||||
last_status: Optional[int] = None
|
||||
last_status: int | None = None
|
||||
|
||||
while True:
|
||||
# Check timeout before each poll
|
||||
|
|
@ -233,10 +241,10 @@ class SourcesAPI:
|
|||
async def wait_for_sources(
|
||||
self,
|
||||
notebook_id: str,
|
||||
source_ids: List[str],
|
||||
source_ids: builtins.list[str],
|
||||
timeout: float = 120.0,
|
||||
**kwargs: Any,
|
||||
) -> List[Source]:
|
||||
) -> builtins.list[Source]:
|
||||
"""Wait for multiple sources to become ready in parallel.
|
||||
|
||||
Args:
|
||||
|
|
@ -263,8 +271,7 @@ class SourcesAPI:
|
|||
)
|
||||
"""
|
||||
tasks = [
|
||||
self.wait_until_ready(notebook_id, sid, timeout=timeout, **kwargs)
|
||||
for sid in source_ids
|
||||
self.wait_until_ready(notebook_id, sid, timeout=timeout, **kwargs) for sid in source_ids
|
||||
]
|
||||
return list(await asyncio.gather(*tasks))
|
||||
|
||||
|
|
@ -353,8 +360,8 @@ class SourcesAPI:
|
|||
async def add_file(
|
||||
self,
|
||||
notebook_id: str,
|
||||
file_path: Union[str, Path],
|
||||
mime_type: Optional[str] = None,
|
||||
file_path: str | Path,
|
||||
mime_type: str | None = None,
|
||||
wait: bool = False,
|
||||
wait_timeout: float = 120.0,
|
||||
) -> Source:
|
||||
|
|
@ -397,9 +404,7 @@ class SourcesAPI:
|
|||
source_id = await self._register_file_source(notebook_id, filename)
|
||||
|
||||
# Step 2: Start resumable upload with the SOURCE_ID from step 1
|
||||
upload_url = await self._start_resumable_upload(
|
||||
notebook_id, filename, file_size, source_id
|
||||
)
|
||||
upload_url = await self._start_resumable_upload(notebook_id, filename, file_size, source_id)
|
||||
|
||||
# Step 3: Stream upload file content (memory-efficient)
|
||||
await self._upload_file_streaming(upload_url, file_path)
|
||||
|
|
@ -452,7 +457,15 @@ class SourcesAPI:
|
|||
# Drive source structure: [[file_id, mime_type, 1, title], null x9, 1]
|
||||
source_data = [
|
||||
[file_id, mime_type, 1, title],
|
||||
None, None, None, None, None, None, None, None, None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
1,
|
||||
]
|
||||
params = [
|
||||
|
|
@ -552,7 +565,7 @@ class SourcesAPI:
|
|||
# False means stale, True means fresh
|
||||
return result is True
|
||||
|
||||
async def get_guide(self, notebook_id: str, source_id: str) -> Dict[str, Any]:
|
||||
async def get_guide(self, notebook_id: str, source_id: str) -> dict[str, Any]:
|
||||
"""Get AI-generated summary and keywords for a specific source.
|
||||
|
||||
This is the "Source Guide" feature shown when clicking on a source
|
||||
|
|
@ -596,22 +609,18 @@ class SourcesAPI:
|
|||
# Private helper methods
|
||||
# =========================================================================
|
||||
|
||||
def _extract_youtube_video_id(self, url: str) -> Optional[str]:
|
||||
def _extract_youtube_video_id(self, url: str) -> str | None:
|
||||
"""Extract YouTube video ID from various URL formats."""
|
||||
# Short URLs: youtu.be/VIDEO_ID
|
||||
match = re.match(r"https?://youtu\.be/([a-zA-Z0-9_-]+)", url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
# Standard watch URLs: youtube.com/watch?v=VIDEO_ID
|
||||
match = re.match(
|
||||
r"https?://(?:www\.)?youtube\.com/watch\?v=([a-zA-Z0-9_-]+)", url
|
||||
)
|
||||
match = re.match(r"https?://(?:www\.)?youtube\.com/watch\?v=([a-zA-Z0-9_-]+)", url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
# Shorts URLs: youtube.com/shorts/VIDEO_ID
|
||||
match = re.match(
|
||||
r"https?://(?:www\.)?youtube\.com/shorts/([a-zA-Z0-9_-]+)", url
|
||||
)
|
||||
match = re.match(r"https?://(?:www\.)?youtube\.com/shorts/([a-zA-Z0-9_-]+)", url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
|
@ -666,6 +675,7 @@ class SourcesAPI:
|
|||
# Parse SOURCE_ID from response - handle various nesting formats
|
||||
# API returns different structures: [[[[id]]]], [[[id]]], [[id]], etc.
|
||||
if result and isinstance(result, list):
|
||||
|
||||
def extract_id(data):
|
||||
"""Recursively extract first string from nested lists."""
|
||||
if isinstance(data, str):
|
||||
|
|
@ -704,11 +714,13 @@ class SourcesAPI:
|
|||
"x-goog-upload-protocol": "resumable",
|
||||
}
|
||||
|
||||
body = json.dumps({
|
||||
"PROJECT_ID": notebook_id,
|
||||
"SOURCE_NAME": filename,
|
||||
"SOURCE_ID": source_id,
|
||||
})
|
||||
body = json.dumps(
|
||||
{
|
||||
"PROJECT_ID": notebook_id,
|
||||
"SOURCE_NAME": filename,
|
||||
"SOURCE_ID": source_id,
|
||||
}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(url, headers=headers, content=body)
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ import os
|
|||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -78,7 +78,7 @@ class AuthTokens:
|
|||
return "; ".join(f"{k}={v}" for k, v in self.cookies.items())
|
||||
|
||||
@classmethod
|
||||
async def from_storage(cls, path: Optional[Path] = None) -> "AuthTokens":
|
||||
async def from_storage(cls, path: Path | None = None) -> "AuthTokens":
|
||||
"""Create AuthTokens from Playwright storage state file.
|
||||
|
||||
This is the recommended way to create AuthTokens for programmatic use.
|
||||
|
|
@ -134,8 +134,7 @@ def extract_cookies_from_storage(storage_state: dict[str, Any]) -> dict[str, str
|
|||
missing = MINIMUM_REQUIRED_COOKIES - set(cookies.keys())
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"Missing required cookies: {missing}\n"
|
||||
f"Run 'notebooklm login' to authenticate."
|
||||
f"Missing required cookies: {missing}\nRun 'notebooklm login' to authenticate."
|
||||
)
|
||||
|
||||
return cookies
|
||||
|
|
@ -164,8 +163,7 @@ def extract_csrf_from_html(html: str, final_url: str = "") -> str:
|
|||
# Check if we were redirected to login page
|
||||
if "accounts.google.com" in final_url or "accounts.google.com" in html:
|
||||
raise ValueError(
|
||||
"Authentication expired or invalid. "
|
||||
"Run 'notebooklm login' to re-authenticate."
|
||||
"Authentication expired or invalid. Run 'notebooklm login' to re-authenticate."
|
||||
)
|
||||
raise ValueError(
|
||||
f"CSRF token not found in HTML. Final URL: {final_url}\n"
|
||||
|
|
@ -196,8 +194,7 @@ def extract_session_id_from_html(html: str, final_url: str = "") -> str:
|
|||
if not match:
|
||||
if "accounts.google.com" in final_url or "accounts.google.com" in html:
|
||||
raise ValueError(
|
||||
"Authentication expired or invalid. "
|
||||
"Run 'notebooklm login' to re-authenticate."
|
||||
"Authentication expired or invalid. Run 'notebooklm login' to re-authenticate."
|
||||
)
|
||||
raise ValueError(
|
||||
f"Session ID not found in HTML. Final URL: {final_url}\n"
|
||||
|
|
@ -206,7 +203,7 @@ def extract_session_id_from_html(html: str, final_url: str = "") -> str:
|
|||
return match.group(1)
|
||||
|
||||
|
||||
def _load_storage_state(path: Optional[Path] = None) -> dict[str, Any]:
|
||||
def _load_storage_state(path: Path | None = None) -> dict[str, Any]:
|
||||
"""Load Playwright storage state from file or environment variable.
|
||||
|
||||
This is a shared helper used by load_auth_from_storage() and load_httpx_cookies()
|
||||
|
|
@ -231,8 +228,7 @@ def _load_storage_state(path: Optional[Path] = None) -> dict[str, Any]:
|
|||
if path:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Storage file not found: {path}\n"
|
||||
f"Run 'notebooklm login' to authenticate first."
|
||||
f"Storage file not found: {path}\nRun 'notebooklm login' to authenticate first."
|
||||
)
|
||||
return json.loads(path.read_text())
|
||||
|
||||
|
|
@ -257,7 +253,7 @@ def _load_storage_state(path: Optional[Path] = None) -> dict[str, Any]:
|
|||
raise ValueError(
|
||||
"NOTEBOOKLM_AUTH_JSON must contain valid Playwright storage state "
|
||||
"with a 'cookies' key.\n"
|
||||
"Expected format: {\"cookies\": [{\"name\": \"SID\", \"value\": \"...\", ...}]}"
|
||||
'Expected format: {"cookies": [{"name": "SID", "value": "...", ...}]}'
|
||||
)
|
||||
return storage_state
|
||||
|
||||
|
|
@ -266,14 +262,13 @@ def _load_storage_state(path: Optional[Path] = None) -> dict[str, Any]:
|
|||
|
||||
if not storage_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Storage file not found: {storage_path}\n"
|
||||
f"Run 'notebooklm login' to authenticate first."
|
||||
f"Storage file not found: {storage_path}\nRun 'notebooklm login' to authenticate first."
|
||||
)
|
||||
|
||||
return json.loads(storage_path.read_text())
|
||||
|
||||
|
||||
def load_auth_from_storage(path: Optional[Path] = None) -> dict[str, str]:
|
||||
def load_auth_from_storage(path: Path | None = None) -> dict[str, str]:
|
||||
"""Load Google cookies from storage.
|
||||
|
||||
Loads authentication cookies with the following precedence:
|
||||
|
|
@ -336,7 +331,7 @@ def _is_allowed_cookie_domain(domain: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def load_httpx_cookies(path: Optional[Path] = None) -> "httpx.Cookies":
|
||||
def load_httpx_cookies(path: Path | None = None) -> "httpx.Cookies":
|
||||
"""Load cookies as an httpx.Cookies object for authenticated downloads.
|
||||
|
||||
Unlike load_auth_from_storage() which returns a simple dict, this function
|
||||
|
|
|
|||
|
|
@ -16,67 +16,65 @@ Re-exports from helpers for backward compatibility with tests.
|
|||
"""
|
||||
|
||||
# Command groups (subcommand style)
|
||||
from .source import source
|
||||
from .artifact import artifact
|
||||
from .generate import generate
|
||||
from .chat import register_chat_commands
|
||||
from .download import download
|
||||
from .generate import generate
|
||||
from .helpers import (
|
||||
# Display
|
||||
ARTIFACT_TYPE_DISPLAY,
|
||||
ARTIFACT_TYPE_MAP,
|
||||
BROWSER_PROFILE_DIR,
|
||||
# Context
|
||||
CONTEXT_FILE,
|
||||
clear_context,
|
||||
# Console
|
||||
console,
|
||||
detect_source_type,
|
||||
get_artifact_type_display,
|
||||
get_auth_tokens,
|
||||
# Auth
|
||||
get_client,
|
||||
get_current_conversation,
|
||||
get_current_notebook,
|
||||
get_source_type_display,
|
||||
handle_auth_error,
|
||||
# Errors
|
||||
handle_error,
|
||||
json_error_response,
|
||||
# Output
|
||||
json_output_response,
|
||||
require_notebook,
|
||||
resolve_artifact_id,
|
||||
resolve_notebook_id,
|
||||
resolve_source_id,
|
||||
# Async
|
||||
run_async,
|
||||
set_current_conversation,
|
||||
set_current_notebook,
|
||||
# Decorators
|
||||
with_client,
|
||||
)
|
||||
from .note import note
|
||||
from .skill import skill
|
||||
from .notebook import register_notebook_commands
|
||||
from .options import (
|
||||
artifact_option,
|
||||
generate_options,
|
||||
json_option,
|
||||
# Individual option decorators
|
||||
notebook_option,
|
||||
output_option,
|
||||
source_option,
|
||||
# Composite decorators
|
||||
standard_options,
|
||||
wait_option,
|
||||
)
|
||||
from .research import research
|
||||
|
||||
# Register functions (top-level command style)
|
||||
from .session import register_session_commands
|
||||
from .notebook import register_notebook_commands
|
||||
from .chat import register_chat_commands
|
||||
|
||||
from .helpers import (
|
||||
# Console
|
||||
console,
|
||||
# Async
|
||||
run_async,
|
||||
# Auth
|
||||
get_client,
|
||||
get_auth_tokens,
|
||||
# Context
|
||||
CONTEXT_FILE,
|
||||
BROWSER_PROFILE_DIR,
|
||||
get_current_notebook,
|
||||
set_current_notebook,
|
||||
clear_context,
|
||||
get_current_conversation,
|
||||
set_current_conversation,
|
||||
require_notebook,
|
||||
resolve_notebook_id,
|
||||
resolve_source_id,
|
||||
resolve_artifact_id,
|
||||
# Errors
|
||||
handle_error,
|
||||
handle_auth_error,
|
||||
# Decorators
|
||||
with_client,
|
||||
# Output
|
||||
json_output_response,
|
||||
json_error_response,
|
||||
# Display
|
||||
ARTIFACT_TYPE_DISPLAY,
|
||||
ARTIFACT_TYPE_MAP,
|
||||
get_artifact_type_display,
|
||||
detect_source_type,
|
||||
get_source_type_display,
|
||||
)
|
||||
|
||||
from .options import (
|
||||
# Individual option decorators
|
||||
notebook_option,
|
||||
json_option,
|
||||
wait_option,
|
||||
source_option,
|
||||
artifact_option,
|
||||
output_option,
|
||||
# Composite decorators
|
||||
standard_options,
|
||||
generate_options,
|
||||
)
|
||||
from .skill import skill
|
||||
from .source import source
|
||||
|
||||
__all__ = [
|
||||
# Command groups (subcommand style)
|
||||
|
|
|
|||
|
|
@ -19,13 +19,13 @@ from rich.table import Table
|
|||
from ..client import NotebookLMClient
|
||||
from ..rpc import ExportType
|
||||
from .helpers import (
|
||||
ARTIFACT_TYPE_MAP,
|
||||
console,
|
||||
get_artifact_type_display,
|
||||
json_output_response,
|
||||
require_notebook,
|
||||
resolve_artifact_id,
|
||||
with_client,
|
||||
json_output_response,
|
||||
get_artifact_type_display,
|
||||
ARTIFACT_TYPE_MAP,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -83,9 +83,7 @@ def artifact():
|
|||
def artifact_list(ctx, notebook_id, artifact_type, json_output, client_auth):
|
||||
"""List artifacts in a notebook."""
|
||||
nb_id = require_notebook(notebook_id)
|
||||
type_filter = (
|
||||
None if artifact_type == "all" else ARTIFACT_TYPE_MAP.get(artifact_type)
|
||||
)
|
||||
type_filter = None if artifact_type == "all" else ARTIFACT_TYPE_MAP.get(artifact_type)
|
||||
|
||||
async def _run():
|
||||
async with NotebookLMClient(client_auth) as client:
|
||||
|
|
@ -97,6 +95,7 @@ def artifact_list(ctx, notebook_id, artifact_type, json_output, client_auth):
|
|||
nb = await client.notebooks.get(nb_id)
|
||||
|
||||
if json_output:
|
||||
|
||||
def _get_status_str(art):
|
||||
if art.is_completed:
|
||||
return "completed"
|
||||
|
|
@ -118,9 +117,7 @@ def artifact_list(ctx, notebook_id, artifact_type, json_output, client_auth):
|
|||
"type_id": art.artifact_type,
|
||||
"status": _get_status_str(art),
|
||||
"status_id": art.status,
|
||||
"created_at": art.created_at.isoformat()
|
||||
if art.created_at
|
||||
else None,
|
||||
"created_at": art.created_at.isoformat() if art.created_at else None,
|
||||
}
|
||||
for i, art in enumerate(artifacts, 1)
|
||||
],
|
||||
|
|
@ -144,9 +141,7 @@ def artifact_list(ctx, notebook_id, artifact_type, json_output, client_auth):
|
|||
type_display = get_artifact_type_display(
|
||||
art.artifact_type, art.variant, art.report_subtype
|
||||
)
|
||||
created = (
|
||||
art.created_at.strftime("%Y-%m-%d %H:%M") if art.created_at else "-"
|
||||
)
|
||||
created = art.created_at.strftime("%Y-%m-%d %H:%M") if art.created_at else "-"
|
||||
status = (
|
||||
"completed"
|
||||
if art.is_completed
|
||||
|
|
@ -290,9 +285,7 @@ def artifact_delete(ctx, artifact_id, notebook_id, yes, client_auth):
|
|||
help="Notebook ID (uses current if not set). Supports partial IDs.",
|
||||
)
|
||||
@click.option("--title", required=True, help="Title for exported document")
|
||||
@click.option(
|
||||
"--type", "export_type", type=click.Choice(["docs", "sheets"]), default="docs"
|
||||
)
|
||||
@click.option("--type", "export_type", type=click.Choice(["docs", "sheets"]), default="docs")
|
||||
@with_client
|
||||
def artifact_export(ctx, artifact_id, notebook_id, title, export_type, client_auth):
|
||||
"""Export artifact to Google Docs/Sheets.
|
||||
|
|
@ -384,7 +377,8 @@ def artifact_wait(ctx, artifact_id, notebook_id, timeout, interval, json_output,
|
|||
|
||||
try:
|
||||
status = await client.artifacts.wait_for_completion(
|
||||
nb_id, resolved_id,
|
||||
nb_id,
|
||||
resolved_id,
|
||||
poll_interval=float(interval),
|
||||
timeout=float(timeout),
|
||||
)
|
||||
|
|
@ -410,11 +404,13 @@ def artifact_wait(ctx, artifact_id, notebook_id, timeout, interval, json_output,
|
|||
|
||||
except TimeoutError:
|
||||
if json_output:
|
||||
json_output_response({
|
||||
"artifact_id": resolved_id,
|
||||
"status": "timeout",
|
||||
"error": f"Timed out after {timeout} seconds",
|
||||
})
|
||||
json_output_response(
|
||||
{
|
||||
"artifact_id": resolved_id,
|
||||
"status": "timeout",
|
||||
"error": f"Timed out after {timeout} seconds",
|
||||
}
|
||||
)
|
||||
else:
|
||||
console.print(f"[red]✗ Timeout after {timeout}s[/red]")
|
||||
raise SystemExit(1)
|
||||
|
|
@ -463,10 +459,6 @@ def artifact_suggestions(ctx, notebook_id, source_ids, json_output, client_auth)
|
|||
table.add_row(str(i), suggestion.title, suggestion.description)
|
||||
|
||||
console.print(table)
|
||||
console.print(
|
||||
'\n[dim]Use the prompt with: notebooklm generate report "<prompt>"[/dim]'
|
||||
)
|
||||
console.print('\n[dim]Use the prompt with: notebooklm generate report "<prompt>"[/dim]')
|
||||
|
||||
return _run()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@ from ..client import NotebookLMClient
|
|||
from ..types import ChatMode
|
||||
from .helpers import (
|
||||
console,
|
||||
require_notebook,
|
||||
with_client,
|
||||
get_current_conversation,
|
||||
require_notebook,
|
||||
set_current_conversation,
|
||||
with_client,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -32,12 +32,8 @@ def register_chat_commands(cli):
|
|||
default=None,
|
||||
help="Notebook ID (uses current if not set)",
|
||||
)
|
||||
@click.option(
|
||||
"--conversation-id", "-c", default=None, help="Continue a specific conversation"
|
||||
)
|
||||
@click.option(
|
||||
"--new", "new_conversation", is_flag=True, help="Start a new conversation"
|
||||
)
|
||||
@click.option("--conversation-id", "-c", default=None, help="Continue a specific conversation")
|
||||
@click.option("--new", "new_conversation", is_flag=True, help="Start a new conversation")
|
||||
@with_client
|
||||
def ask_cmd(ctx, question, notebook_id, conversation_id, new_conversation, client_auth):
|
||||
"""Ask a notebook a question.
|
||||
|
|
@ -68,9 +64,7 @@ def register_chat_commands(cli):
|
|||
if history and history[0]:
|
||||
last_conv = history[0][-1]
|
||||
effective_conv_id = (
|
||||
last_conv[0]
|
||||
if isinstance(last_conv, list)
|
||||
else str(last_conv)
|
||||
last_conv[0] if isinstance(last_conv, list) else str(last_conv)
|
||||
)
|
||||
console.print(
|
||||
f"[dim]Continuing conversation {effective_conv_id[:8]}...[/dim]"
|
||||
|
|
@ -78,9 +72,7 @@ def register_chat_commands(cli):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
result = await client.chat.ask(
|
||||
nb_id, question, conversation_id=effective_conv_id
|
||||
)
|
||||
result = await client.chat.ask(nb_id, question, conversation_id=effective_conv_id)
|
||||
|
||||
if result.conversation_id:
|
||||
set_current_conversation(result.conversation_id)
|
||||
|
|
@ -111,9 +103,7 @@ def register_chat_commands(cli):
|
|||
default=None,
|
||||
help="Predefined chat mode",
|
||||
)
|
||||
@click.option(
|
||||
"--persona", default=None, help="Custom persona prompt (up to 10,000 chars)"
|
||||
)
|
||||
@click.option("--persona", default=None, help="Custom persona prompt (up to 10,000 chars)")
|
||||
@click.option(
|
||||
"--response-length",
|
||||
type=click.Choice(["default", "longer", "shorter"]),
|
||||
|
|
@ -206,6 +196,7 @@ def register_chat_commands(cli):
|
|||
notebooklm history -n nb123 # Show history for specific notebook
|
||||
notebooklm history --clear # Clear local cache
|
||||
"""
|
||||
|
||||
async def _run():
|
||||
async with NotebookLMClient(client_auth) as client:
|
||||
if clear_cache:
|
||||
|
|
@ -228,9 +219,7 @@ def register_chat_commands(cli):
|
|||
table.add_column("#", style="dim")
|
||||
table.add_column("Conversation ID", style="cyan")
|
||||
for i, conv in enumerate(conversations, 1):
|
||||
conv_id = (
|
||||
conv[0] if isinstance(conv, list) and conv else str(conv)
|
||||
)
|
||||
conv_id = conv[0] if isinstance(conv, list) and conv else str(conv)
|
||||
table.add_row(str(i), conv_id)
|
||||
console.print(table)
|
||||
console.print(
|
||||
|
|
|
|||
|
|
@ -13,16 +13,16 @@ from typing import Any
|
|||
|
||||
import click
|
||||
|
||||
from ..auth import AuthTokens, load_auth_from_storage, fetch_tokens
|
||||
from ..auth import AuthTokens, fetch_tokens, load_auth_from_storage
|
||||
from ..client import NotebookLMClient
|
||||
from ..types import Artifact
|
||||
from .download_helpers import ArtifactDict, artifact_title_to_filename, select_artifact
|
||||
from .helpers import (
|
||||
console,
|
||||
run_async,
|
||||
require_notebook,
|
||||
handle_error,
|
||||
require_notebook,
|
||||
run_async,
|
||||
)
|
||||
from .download_helpers import select_artifact, artifact_title_to_filename, ArtifactDict
|
||||
|
||||
|
||||
@click.group()
|
||||
|
|
@ -166,9 +166,7 @@ async def _download_artifacts_generic(
|
|||
|
||||
# Handle --all flag
|
||||
if download_all:
|
||||
output_dir = (
|
||||
Path(output_path) if output_path else Path(default_output_dir)
|
||||
)
|
||||
output_dir = Path(output_path) if output_path else Path(default_output_dir)
|
||||
|
||||
if dry_run:
|
||||
return {
|
||||
|
|
@ -199,9 +197,7 @@ async def _download_artifacts_generic(
|
|||
for i, artifact in enumerate(type_artifacts, 1):
|
||||
# Progress indicator
|
||||
if not json_output:
|
||||
console.print(
|
||||
f"[dim]Downloading {i}/{total}:[/dim] {artifact['title']}"
|
||||
)
|
||||
console.print(f"[dim]Downloading {i}/{total}:[/dim] {artifact['title']}")
|
||||
|
||||
# Generate safe name
|
||||
item_name = artifact_title_to_filename(
|
||||
|
|
@ -220,7 +216,10 @@ async def _download_artifacts_generic(
|
|||
"id": artifact["id"],
|
||||
"title": artifact["title"],
|
||||
"filename": item_name,
|
||||
**(skip_info or {"status": "skipped", "reason": "conflict resolution failed"}),
|
||||
**(
|
||||
skip_info
|
||||
or {"status": "skipped", "reason": "conflict resolution failed"}
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
|
@ -232,9 +231,7 @@ async def _download_artifacts_generic(
|
|||
# Download
|
||||
try:
|
||||
# Download using dispatch
|
||||
await download_fn(
|
||||
nb_id, str(item_path), artifact_id=str(artifact["id"])
|
||||
)
|
||||
await download_fn(nb_id, str(item_path), artifact_id=str(artifact["id"]))
|
||||
|
||||
results.append(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Helper functions for download commands."""
|
||||
|
||||
import re
|
||||
from typing import Optional, TypedDict
|
||||
from typing import TypedDict
|
||||
|
||||
# Reserve space for " (999)" suffix when handling duplicate filenames
|
||||
DUPLICATE_SUFFIX_RESERVE = 7
|
||||
|
|
@ -19,8 +19,8 @@ def select_artifact(
|
|||
artifacts: list[ArtifactDict],
|
||||
latest: bool = True,
|
||||
earliest: bool = False,
|
||||
name: Optional[str] = None,
|
||||
artifact_id: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
artifact_id: str | None = None,
|
||||
) -> tuple[ArtifactDict, str]:
|
||||
"""
|
||||
Select an artifact from a list based on criteria.
|
||||
|
|
@ -106,10 +106,10 @@ def artifact_title_to_filename(
|
|||
"""
|
||||
# Sanitize: replace invalid chars with underscore
|
||||
# Invalid chars: / \ : * ? " < > |
|
||||
sanitized = re.sub(r'[/\\:*?"<>|]', '_', title)
|
||||
sanitized = re.sub(r'[/\\:*?"<>|]', "_", title)
|
||||
|
||||
# Remove leading/trailing whitespace and dots
|
||||
sanitized = sanitized.strip('. ')
|
||||
sanitized = sanitized.strip(". ")
|
||||
|
||||
# Fallback for empty titles
|
||||
if not sanitized:
|
||||
|
|
@ -120,7 +120,7 @@ def artifact_title_to_filename(
|
|||
|
||||
# Truncate if too long
|
||||
if len(sanitized) > effective_max:
|
||||
sanitized = sanitized[:effective_max].rstrip('. ')
|
||||
sanitized = sanitized[:effective_max].rstrip(". ")
|
||||
|
||||
# Build initial filename
|
||||
base = sanitized
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ Commands:
|
|||
report Generate report
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
|
|
@ -20,23 +20,23 @@ from ..client import NotebookLMClient
|
|||
from ..types import (
|
||||
AudioFormat,
|
||||
AudioLength,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
QuizQuantity,
|
||||
QuizDifficulty,
|
||||
InfographicOrientation,
|
||||
GenerationStatus,
|
||||
InfographicDetail,
|
||||
InfographicOrientation,
|
||||
QuizDifficulty,
|
||||
QuizQuantity,
|
||||
ReportFormat,
|
||||
SlideDeckFormat,
|
||||
SlideDeckLength,
|
||||
ReportFormat,
|
||||
GenerationStatus,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
)
|
||||
from .helpers import (
|
||||
console,
|
||||
json_error_response,
|
||||
json_output_response,
|
||||
require_notebook,
|
||||
with_client,
|
||||
json_output_response,
|
||||
json_error_response,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ async def handle_generation_result(
|
|||
wait: bool = False,
|
||||
json_output: bool = False,
|
||||
timeout: float = 300.0,
|
||||
) -> Optional[GenerationStatus]:
|
||||
) -> GenerationStatus | None:
|
||||
"""Handle generation result with optional waiting and output formatting.
|
||||
|
||||
Consolidates common pattern across all generate commands:
|
||||
|
|
@ -98,12 +98,8 @@ async def handle_generation_result(
|
|||
# Wait for completion if requested
|
||||
if wait and task_id:
|
||||
if not json_output:
|
||||
console.print(
|
||||
f"[yellow]Generating {artifact_type}...[/yellow] Task: {task_id}"
|
||||
)
|
||||
status = await client.artifacts.wait_for_completion(
|
||||
notebook_id, task_id, timeout=timeout
|
||||
)
|
||||
console.print(f"[yellow]Generating {artifact_type}...[/yellow] Task: {task_id}")
|
||||
status = await client.artifacts.wait_for_completion(notebook_id, task_id, timeout=timeout)
|
||||
|
||||
# Output status
|
||||
_output_generation_status(status, artifact_type, json_output)
|
||||
|
|
@ -111,17 +107,17 @@ async def handle_generation_result(
|
|||
return status if isinstance(status, GenerationStatus) else None
|
||||
|
||||
|
||||
def _output_generation_status(
|
||||
status: Any, artifact_type: str, json_output: bool
|
||||
) -> None:
|
||||
def _output_generation_status(status: Any, artifact_type: str, json_output: bool) -> None:
|
||||
"""Output generation status in appropriate format."""
|
||||
if json_output:
|
||||
if hasattr(status, "is_complete") and status.is_complete:
|
||||
json_output_response({
|
||||
"artifact_id": getattr(status, "task_id", None),
|
||||
"status": "completed",
|
||||
"url": getattr(status, "url", None),
|
||||
})
|
||||
json_output_response(
|
||||
{
|
||||
"artifact_id": getattr(status, "task_id", None),
|
||||
"status": "completed",
|
||||
"url": getattr(status, "url", None),
|
||||
}
|
||||
)
|
||||
elif hasattr(status, "is_failed") and status.is_failed:
|
||||
json_error_response(
|
||||
"GENERATION_FAILED",
|
||||
|
|
@ -205,9 +201,7 @@ def generate():
|
|||
default="default",
|
||||
)
|
||||
@click.option("--language", default="en")
|
||||
@click.option(
|
||||
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
|
||||
)
|
||||
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
|
||||
@click.option("--json", "json_output", is_flag=True, help="Output as JSON")
|
||||
@with_client
|
||||
def generate_audio(
|
||||
|
|
@ -250,9 +244,7 @@ def generate_audio(
|
|||
audio_format=format_map[audio_format],
|
||||
audio_length=length_map[audio_length],
|
||||
)
|
||||
await handle_generation_result(
|
||||
client, nb_id, result, "audio", wait, json_output
|
||||
)
|
||||
await handle_generation_result(client, nb_id, result, "audio", wait, json_output)
|
||||
|
||||
return _run()
|
||||
|
||||
|
|
@ -290,9 +282,7 @@ def generate_audio(
|
|||
default="auto",
|
||||
)
|
||||
@click.option("--language", default="en")
|
||||
@click.option(
|
||||
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
|
||||
)
|
||||
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
|
||||
@click.option("--json", "json_output", is_flag=True, help="Output as JSON")
|
||||
@with_client
|
||||
def generate_video(
|
||||
|
|
@ -358,9 +348,7 @@ def generate_video(
|
|||
default="default",
|
||||
)
|
||||
@click.option("--language", default="en")
|
||||
@click.option(
|
||||
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
|
||||
)
|
||||
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
|
||||
@with_client
|
||||
def generate_slide_deck(
|
||||
ctx, description, notebook_id, deck_format, deck_length, language, wait, client_auth
|
||||
|
|
@ -391,9 +379,7 @@ def generate_slide_deck(
|
|||
slide_format=format_map[deck_format],
|
||||
slide_length=length_map[deck_length],
|
||||
)
|
||||
await handle_generation_result(
|
||||
client, nb_id, result, "slide deck", wait
|
||||
)
|
||||
await handle_generation_result(client, nb_id, result, "slide deck", wait)
|
||||
|
||||
return _run()
|
||||
|
||||
|
|
@ -407,15 +393,9 @@ def generate_slide_deck(
|
|||
default=None,
|
||||
help="Notebook ID (uses current if not set)",
|
||||
)
|
||||
@click.option(
|
||||
"--quantity", type=click.Choice(["fewer", "standard", "more"]), default="standard"
|
||||
)
|
||||
@click.option(
|
||||
"--difficulty", type=click.Choice(["easy", "medium", "hard"]), default="medium"
|
||||
)
|
||||
@click.option(
|
||||
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
|
||||
)
|
||||
@click.option("--quantity", type=click.Choice(["fewer", "standard", "more"]), default="standard")
|
||||
@click.option("--difficulty", type=click.Choice(["easy", "medium", "hard"]), default="medium")
|
||||
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
|
||||
@with_client
|
||||
def generate_quiz(ctx, description, notebook_id, quantity, difficulty, wait, client_auth):
|
||||
"""Generate quiz.
|
||||
|
|
@ -459,15 +439,9 @@ def generate_quiz(ctx, description, notebook_id, quantity, difficulty, wait, cli
|
|||
default=None,
|
||||
help="Notebook ID (uses current if not set)",
|
||||
)
|
||||
@click.option(
|
||||
"--quantity", type=click.Choice(["fewer", "standard", "more"]), default="standard"
|
||||
)
|
||||
@click.option(
|
||||
"--difficulty", type=click.Choice(["easy", "medium", "hard"]), default="medium"
|
||||
)
|
||||
@click.option(
|
||||
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
|
||||
)
|
||||
@click.option("--quantity", type=click.Choice(["fewer", "standard", "more"]), default="standard")
|
||||
@click.option("--difficulty", type=click.Choice(["easy", "medium", "hard"]), default="medium")
|
||||
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
|
||||
@with_client
|
||||
def generate_flashcards(ctx, description, notebook_id, quantity, difficulty, wait, client_auth):
|
||||
"""Generate flashcards.
|
||||
|
|
@ -522,9 +496,7 @@ def generate_flashcards(ctx, description, notebook_id, quantity, difficulty, wai
|
|||
default="standard",
|
||||
)
|
||||
@click.option("--language", default="en")
|
||||
@click.option(
|
||||
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
|
||||
)
|
||||
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
|
||||
@with_client
|
||||
def generate_infographic(
|
||||
ctx, description, notebook_id, orientation, detail, language, wait, client_auth
|
||||
|
|
@ -572,9 +544,7 @@ def generate_infographic(
|
|||
help="Notebook ID (uses current if not set)",
|
||||
)
|
||||
@click.option("--language", default="en")
|
||||
@click.option(
|
||||
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
|
||||
)
|
||||
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
|
||||
@with_client
|
||||
def generate_data_table(ctx, description, notebook_id, language, wait, client_auth):
|
||||
"""Generate data table.
|
||||
|
|
@ -621,9 +591,7 @@ def generate_mind_map(ctx, notebook_id, client_auth):
|
|||
mind_map = result.get("mind_map", {})
|
||||
if isinstance(mind_map, dict):
|
||||
console.print(f" Root: {mind_map.get('name', '-')}")
|
||||
console.print(
|
||||
f" Children: {len(mind_map.get('children', []))} nodes"
|
||||
)
|
||||
console.print(f" Children: {len(mind_map.get('children', []))} nodes")
|
||||
else:
|
||||
console.print(result)
|
||||
else:
|
||||
|
|
@ -648,9 +616,7 @@ def generate_mind_map(ctx, notebook_id, client_auth):
|
|||
default=None,
|
||||
help="Notebook ID (uses current if not set)",
|
||||
)
|
||||
@click.option(
|
||||
"--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)"
|
||||
)
|
||||
@click.option("--wait/--no-wait", default=False, help="Wait for completion (default: no-wait)")
|
||||
@with_client
|
||||
def generate_report_cmd(ctx, description, report_format, notebook_id, wait, client_auth):
|
||||
"""Generate a report (briefing doc, study guide, blog post, or custom).
|
||||
|
|
|
|||
|
|
@ -20,24 +20,28 @@ class SectionedGroup(click.Group):
|
|||
"""
|
||||
|
||||
# Regular commands - show help text
|
||||
command_sections = OrderedDict([
|
||||
("Session", ["login", "use", "status", "clear"]),
|
||||
("Notebooks", ["list", "create", "delete", "rename", "share", "summary"]),
|
||||
("Chat", ["ask", "configure", "history"]),
|
||||
])
|
||||
command_sections = OrderedDict(
|
||||
[
|
||||
("Session", ["login", "use", "status", "clear"]),
|
||||
("Notebooks", ["list", "create", "delete", "rename", "share", "summary"]),
|
||||
("Chat", ["ask", "configure", "history"]),
|
||||
]
|
||||
)
|
||||
|
||||
# Command groups - show sorted subcommands instead of help text
|
||||
command_groups = OrderedDict([
|
||||
("Command Groups (use: notebooklm <group> <command>)",
|
||||
["source", "artifact", "note", "research"]),
|
||||
("Artifact Actions (use: notebooklm <action> <type>)",
|
||||
["generate", "download"]),
|
||||
])
|
||||
command_groups = OrderedDict(
|
||||
[
|
||||
(
|
||||
"Command Groups (use: notebooklm <group> <command>)",
|
||||
["source", "artifact", "note", "research"],
|
||||
),
|
||||
("Artifact Actions (use: notebooklm <action> <type>)", ["generate", "download"]),
|
||||
]
|
||||
)
|
||||
|
||||
def format_commands(self, ctx, formatter):
|
||||
"""Override to display commands in sections."""
|
||||
commands = {name: self.get_command(ctx, name)
|
||||
for name in self.list_commands(ctx)}
|
||||
commands = {name: self.get_command(ctx, name) for name in self.list_commands(ctx)}
|
||||
|
||||
# Regular command sections (show help text)
|
||||
for section, cmd_names in self.command_sections.items():
|
||||
|
|
@ -67,9 +71,13 @@ class SectionedGroup(click.Group):
|
|||
# Safety net: show any commands not in any section
|
||||
all_listed = set(sum(self.command_sections.values(), []))
|
||||
all_listed |= set(sum(self.command_groups.values(), []))
|
||||
unlisted = [(n, c) for n, c in commands.items()
|
||||
if n not in all_listed and c is not None and not c.hidden]
|
||||
unlisted = [
|
||||
(n, c)
|
||||
for n, c in commands.items()
|
||||
if n not in all_listed and c is not None and not c.hidden
|
||||
]
|
||||
if unlisted:
|
||||
with formatter.section("Other"):
|
||||
formatter.write_dl([(n, c.get_short_help_str(limit=formatter.width))
|
||||
for n, c in unlisted])
|
||||
formatter.write_dl(
|
||||
[(n, c.get_short_help_str(limit=formatter.width)) for n, c in unlisted]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ from rich.console import Console
|
|||
|
||||
from ..auth import (
|
||||
AuthTokens,
|
||||
load_auth_from_storage,
|
||||
fetch_tokens,
|
||||
load_auth_from_storage,
|
||||
)
|
||||
from ..paths import get_context_path, get_browser_profile_dir
|
||||
from ..paths import get_browser_profile_dir, get_context_path
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -116,7 +116,7 @@ def get_current_notebook() -> str | None:
|
|||
try:
|
||||
data = json.loads(context_file.read_text())
|
||||
return data.get("notebook_id")
|
||||
except (json.JSONDecodeError, IOError):
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -154,7 +154,7 @@ def get_current_conversation() -> str | None:
|
|||
try:
|
||||
data = json.loads(context_file.read_text())
|
||||
return data.get("conversation_id")
|
||||
except (json.JSONDecodeError, IOError):
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -170,7 +170,7 @@ def set_current_conversation(conversation_id: str | None):
|
|||
elif "conversation_id" in data:
|
||||
del data["conversation_id"]
|
||||
context_file.write_text(json.dumps(data, indent=2))
|
||||
except (json.JSONDecodeError, IOError):
|
||||
except (OSError, json.JSONDecodeError):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -247,8 +247,7 @@ async def _resolve_partial_id(
|
|||
return partial_id
|
||||
|
||||
items = await list_fn()
|
||||
matches = [item for item in items
|
||||
if item.id.lower().startswith(partial_id.lower())]
|
||||
matches = [item for item in items if item.id.lower().startswith(partial_id.lower())]
|
||||
|
||||
if len(matches) == 1:
|
||||
if matches[0].id != partial_id:
|
||||
|
|
@ -315,13 +314,9 @@ def handle_error(e: Exception):
|
|||
def handle_auth_error(json_output: bool = False):
|
||||
"""Handle authentication errors."""
|
||||
if json_output:
|
||||
json_error_response(
|
||||
"AUTH_REQUIRED", "Auth not found. Run 'notebooklm login' first."
|
||||
)
|
||||
json_error_response("AUTH_REQUIRED", "Auth not found. Run 'notebooklm login' first.")
|
||||
else:
|
||||
console.print(
|
||||
"[red]Not logged in. Run 'notebooklm login' first.[/red]"
|
||||
)
|
||||
console.print("[red]Not logged in. Run 'notebooklm login' first.[/red]")
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
|
|
@ -358,6 +353,7 @@ def with_client(f):
|
|||
Returns:
|
||||
Decorated function with Click pass_context
|
||||
"""
|
||||
|
||||
@wraps(f)
|
||||
@click.pass_context
|
||||
def wrapper(ctx, *args, **kwargs):
|
||||
|
|
@ -374,6 +370,7 @@ def with_client(f):
|
|||
json_error_response("ERROR", str(e))
|
||||
else:
|
||||
handle_error(e)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
|
|
@ -389,9 +386,7 @@ def json_output_response(data: dict) -> None:
|
|||
|
||||
def json_error_response(code: str, message: str) -> None:
|
||||
"""Print JSON error and exit."""
|
||||
console.print(
|
||||
json.dumps({"error": True, "code": code, "message": message}, indent=2)
|
||||
)
|
||||
console.print(json.dumps({"error": True, "code": code, "message": message}, indent=2))
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ def note_list(ctx, notebook_id, client_auth):
|
|||
table.add_row(
|
||||
n.id,
|
||||
n.title or "Untitled",
|
||||
preview + "..." if len(n.content or "") > 50 else preview
|
||||
preview + "..." if len(n.content or "") > 50 else preview,
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
|
|
|||
|
|
@ -14,13 +14,13 @@ from rich.table import Table
|
|||
|
||||
from ..client import NotebookLMClient
|
||||
from .helpers import (
|
||||
console,
|
||||
require_notebook,
|
||||
with_client,
|
||||
json_output_response,
|
||||
get_current_notebook,
|
||||
clear_context,
|
||||
console,
|
||||
get_current_notebook,
|
||||
json_output_response,
|
||||
require_notebook,
|
||||
resolve_notebook_id,
|
||||
with_client,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -32,6 +32,7 @@ def register_notebook_commands(cli):
|
|||
@with_client
|
||||
def list_cmd(ctx, json_output, client_auth):
|
||||
"""List all notebooks."""
|
||||
|
||||
async def _run():
|
||||
async with NotebookLMClient(client_auth) as client:
|
||||
notebooks = await client.notebooks.list()
|
||||
|
|
@ -44,9 +45,7 @@ def register_notebook_commands(cli):
|
|||
"id": nb.id,
|
||||
"title": nb.title,
|
||||
"is_owner": nb.is_owner,
|
||||
"created_at": nb.created_at.isoformat()
|
||||
if nb.created_at
|
||||
else None,
|
||||
"created_at": nb.created_at.isoformat() if nb.created_at else None,
|
||||
}
|
||||
for i, nb in enumerate(notebooks, 1)
|
||||
],
|
||||
|
|
@ -76,6 +75,7 @@ def register_notebook_commands(cli):
|
|||
@with_client
|
||||
def create_cmd(ctx, title, json_output, client_auth):
|
||||
"""Create a new notebook."""
|
||||
|
||||
async def _run():
|
||||
async with NotebookLMClient(client_auth) as client:
|
||||
nb = await client.notebooks.create(title)
|
||||
|
|
@ -85,9 +85,7 @@ def register_notebook_commands(cli):
|
|||
"notebook": {
|
||||
"id": nb.id,
|
||||
"title": nb.title,
|
||||
"created_at": nb.created_at.isoformat()
|
||||
if nb.created_at
|
||||
else None,
|
||||
"created_at": nb.created_at.isoformat() if nb.created_at else None,
|
||||
}
|
||||
}
|
||||
json_output_response(data)
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ from rich.table import Table
|
|||
from ..client import NotebookLMClient
|
||||
from .helpers import (
|
||||
console,
|
||||
json_output_response,
|
||||
require_notebook,
|
||||
with_client,
|
||||
json_output_response,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -102,9 +102,7 @@ def research_status(ctx, notebook_id, json_output, client_auth):
|
|||
if summary:
|
||||
console.print(f"\n[bold]Summary:[/bold]\n{summary[:500]}")
|
||||
|
||||
console.print(
|
||||
"\n[dim]Use 'research wait --import-all' to import sources[/dim]"
|
||||
)
|
||||
console.print("\n[dim]Use 'research wait --import-all' to import sources[/dim]")
|
||||
else:
|
||||
console.print(f"[yellow]Status: {status_val}[/yellow]")
|
||||
|
||||
|
|
@ -134,9 +132,7 @@ def research_status(ctx, notebook_id, json_output, client_auth):
|
|||
@click.option("--import-all", is_flag=True, help="Import all found sources when done")
|
||||
@click.option("--json", "json_output", is_flag=True, help="Output as JSON")
|
||||
@with_client
|
||||
def research_wait(
|
||||
ctx, notebook_id, timeout, interval, import_all, json_output, client_auth
|
||||
):
|
||||
def research_wait(ctx, notebook_id, timeout, interval, import_all, json_output, client_auth):
|
||||
"""Wait for research to complete.
|
||||
|
||||
Blocks until research is completed or timeout is reached.
|
||||
|
|
@ -195,9 +191,7 @@ def research_wait(
|
|||
"sources": sources,
|
||||
}
|
||||
if import_all and sources and task_id:
|
||||
imported = await client.research.import_sources(
|
||||
nb_id, task_id, sources
|
||||
)
|
||||
imported = await client.research.import_sources(nb_id, task_id, sources)
|
||||
result["imported"] = len(imported)
|
||||
result["imported_sources"] = imported
|
||||
json_output_response(result)
|
||||
|
|
@ -207,9 +201,7 @@ def research_wait(
|
|||
|
||||
if import_all and sources and task_id:
|
||||
with console.status("Importing sources..."):
|
||||
imported = await client.research.import_sources(
|
||||
nb_id, task_id, sources
|
||||
)
|
||||
imported = await client.research.import_sources(nb_id, task_id, sources)
|
||||
console.print(f"[green]Imported {len(imported)} sources[/green]")
|
||||
|
||||
return _run()
|
||||
|
|
|
|||
|
|
@ -17,20 +17,20 @@ from rich.table import Table
|
|||
from ..auth import AuthTokens
|
||||
from ..client import NotebookLMClient
|
||||
from ..paths import (
|
||||
get_storage_path,
|
||||
get_context_path,
|
||||
get_browser_profile_dir,
|
||||
get_context_path,
|
||||
get_path_info,
|
||||
get_storage_path,
|
||||
)
|
||||
from .helpers import (
|
||||
clear_context,
|
||||
console,
|
||||
run_async,
|
||||
get_client,
|
||||
get_current_notebook,
|
||||
set_current_notebook,
|
||||
clear_context,
|
||||
json_output_response,
|
||||
resolve_notebook_id,
|
||||
run_async,
|
||||
set_current_notebook,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -216,7 +216,9 @@ def register_session_commands(cli):
|
|||
|
||||
# Show if NOTEBOOKLM_AUTH_JSON is set
|
||||
if os.environ.get("NOTEBOOKLM_AUTH_JSON"):
|
||||
console.print("[yellow]Note: NOTEBOOKLM_AUTH_JSON is set (inline auth active)[/yellow]\n")
|
||||
console.print(
|
||||
"[yellow]Note: NOTEBOOKLM_AUTH_JSON is set (inline auth active)[/yellow]\n"
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
return
|
||||
|
|
@ -254,11 +256,9 @@ def register_session_commands(cli):
|
|||
if conversation_id:
|
||||
table.add_row("Conversation", conversation_id)
|
||||
else:
|
||||
table.add_row(
|
||||
"Conversation", "[dim]None (will auto-select on next ask)[/dim]"
|
||||
)
|
||||
table.add_row("Conversation", "[dim]None (will auto-select on next ask)[/dim]")
|
||||
console.print(table)
|
||||
except (json.JSONDecodeError, IOError):
|
||||
except (OSError, json.JSONDecodeError):
|
||||
if json_output:
|
||||
json_data = {
|
||||
"has_context": True,
|
||||
|
|
|
|||
|
|
@ -6,19 +6,17 @@ Commands for managing the Claude Code skill integration.
|
|||
import re
|
||||
from importlib import resources
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
from .helpers import console
|
||||
|
||||
|
||||
# Skill paths
|
||||
SKILL_DEST_DIR = Path.home() / ".claude" / "skills" / "notebooklm"
|
||||
SKILL_DEST = SKILL_DEST_DIR / "SKILL.md"
|
||||
|
||||
|
||||
def get_skill_source_content() -> Optional[str]:
|
||||
def get_skill_source_content() -> str | None:
|
||||
"""Read the skill source file from package data."""
|
||||
try:
|
||||
# Python 3.9+ way to read package data (use / operator for path traversal)
|
||||
|
|
@ -31,6 +29,7 @@ def get_package_version() -> str:
|
|||
"""Get the current package version."""
|
||||
try:
|
||||
from .. import __version__
|
||||
|
||||
return __version__
|
||||
except ImportError:
|
||||
return "unknown"
|
||||
|
|
@ -44,7 +43,7 @@ def get_skill_version(skill_path: Path) -> str | None:
|
|||
with open(skill_path) as f:
|
||||
content = f.read(500) # Read first 500 chars
|
||||
|
||||
match = re.search(r'notebooklm-py v([\d.]+)', content)
|
||||
match = re.search(r"notebooklm-py v([\d.]+)", content)
|
||||
return match.group(1) if match else None
|
||||
|
||||
|
||||
|
|
@ -116,7 +115,9 @@ def status():
|
|||
|
||||
if skill_version and skill_version != cli_version:
|
||||
console.print("")
|
||||
console.print("[yellow]Version mismatch![/yellow] Run [cyan]notebooklm skill install[/cyan] to update.")
|
||||
console.print(
|
||||
"[yellow]Version mismatch![/yellow] Run [cyan]notebooklm skill install[/cyan] to update."
|
||||
)
|
||||
|
||||
|
||||
@skill.command()
|
||||
|
|
|
|||
|
|
@ -22,11 +22,11 @@ from rich.table import Table
|
|||
from ..client import NotebookLMClient
|
||||
from .helpers import (
|
||||
console,
|
||||
get_source_type_display,
|
||||
json_output_response,
|
||||
require_notebook,
|
||||
resolve_source_id,
|
||||
with_client,
|
||||
json_output_response,
|
||||
get_source_type_display,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -85,9 +85,7 @@ def source_list(ctx, notebook_id, json_output, client_auth):
|
|||
"title": src.title,
|
||||
"type": src.source_type,
|
||||
"url": src.url,
|
||||
"created_at": src.created_at.isoformat()
|
||||
if src.created_at
|
||||
else None,
|
||||
"created_at": src.created_at.isoformat() if src.created_at else None,
|
||||
}
|
||||
for i, src in enumerate(sources, 1)
|
||||
],
|
||||
|
|
@ -104,9 +102,7 @@ def source_list(ctx, notebook_id, json_output, client_auth):
|
|||
|
||||
for src in sources:
|
||||
type_display = get_source_type_display(src.source_type)
|
||||
created = (
|
||||
src.created_at.strftime("%Y-%m-%d %H:%M") if src.created_at else "-"
|
||||
)
|
||||
created = src.created_at.strftime("%Y-%m-%d %H:%M") if src.created_at else "-"
|
||||
table.add_row(src.id, src.title or "-", type_display, created)
|
||||
|
||||
console.print(table)
|
||||
|
|
@ -186,9 +182,7 @@ def source_add(ctx, content, notebook_id, source_type, title, mime_type, json_ou
|
|||
|
||||
async def _run():
|
||||
async with NotebookLMClient(client_auth) as client:
|
||||
if detected_type == "url":
|
||||
src = await client.sources.add_url(nb_id, content)
|
||||
elif detected_type == "youtube":
|
||||
if detected_type == "url" or detected_type == "youtube":
|
||||
src = await client.sources.add_url(nb_id, content)
|
||||
elif detected_type == "text":
|
||||
text_content = file_content if file_content is not None else content
|
||||
|
|
@ -242,9 +236,7 @@ def source_get(ctx, source_id, notebook_id, client_auth):
|
|||
if src:
|
||||
console.print(f"[bold cyan]Source:[/bold cyan] {src.id}")
|
||||
console.print(f"[bold]Title:[/bold] {src.title}")
|
||||
console.print(
|
||||
f"[bold]Type:[/bold] {get_source_type_display(src.source_type)}"
|
||||
)
|
||||
console.print(f"[bold]Type:[/bold] {get_source_type_display(src.source_type)}")
|
||||
if src.url:
|
||||
console.print(f"[bold]URL:[/bold] {src.url}")
|
||||
if src.created_at:
|
||||
|
|
@ -443,9 +435,7 @@ def source_add_research(
|
|||
|
||||
async def _run():
|
||||
async with NotebookLMClient(client_auth) as client:
|
||||
console.print(
|
||||
f"[yellow]Starting {mode} research on {search_source}...[/yellow]"
|
||||
)
|
||||
console.print(f"[yellow]Starting {mode} research on {search_source}...[/yellow]")
|
||||
result = await client.research.start(nb_id, query, search_source, mode)
|
||||
if not result:
|
||||
console.print("[red]Research failed to start[/red]")
|
||||
|
|
@ -479,9 +469,7 @@ def source_add_research(
|
|||
console.print(f"\n[green]Found {len(sources)} sources[/green]")
|
||||
|
||||
if import_all and sources and task_id:
|
||||
imported = await client.research.import_sources(
|
||||
nb_id, task_id, sources
|
||||
)
|
||||
imported = await client.research.import_sources(nb_id, task_id, sources)
|
||||
console.print(f"[green]Imported {len(imported)} sources[/green]")
|
||||
else:
|
||||
console.print(f"[yellow]Status: {status.get('status', 'unknown')}[/yellow]")
|
||||
|
|
@ -634,7 +622,7 @@ def source_wait(ctx, source_id, notebook_id, timeout, json_output, client_auth):
|
|||
notebooklm source add https://example.com
|
||||
# Subagent runs: notebooklm source wait <source_id>
|
||||
"""
|
||||
from ..types import SourceProcessingError, SourceTimeoutError, SourceNotFoundError
|
||||
from ..types import SourceNotFoundError, SourceProcessingError, SourceTimeoutError
|
||||
|
||||
nb_id = require_notebook(notebook_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,16 +20,15 @@ Example:
|
|||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .auth import AuthTokens
|
||||
from ._core import ClientCore, DEFAULT_TIMEOUT
|
||||
from ._notebooks import NotebooksAPI
|
||||
from ._sources import SourcesAPI
|
||||
from ._artifacts import ArtifactsAPI
|
||||
from ._chat import ChatAPI
|
||||
from ._research import ResearchAPI
|
||||
from ._core import DEFAULT_TIMEOUT, ClientCore
|
||||
from ._notebooks import NotebooksAPI
|
||||
from ._notes import NotesAPI
|
||||
from ._research import ResearchAPI
|
||||
from ._sources import SourcesAPI
|
||||
from .auth import AuthTokens
|
||||
|
||||
|
||||
class NotebookLMClient:
|
||||
|
|
@ -102,7 +101,7 @@ class NotebookLMClient:
|
|||
|
||||
@classmethod
|
||||
async def from_storage(
|
||||
cls, path: Optional[str] = None, timeout: float = DEFAULT_TIMEOUT
|
||||
cls, path: str | None = None, timeout: float = DEFAULT_TIMEOUT
|
||||
) -> "NotebookLMClient":
|
||||
"""Create a client from Playwright storage state file.
|
||||
|
||||
|
|
@ -146,9 +145,7 @@ class NotebookLMClient:
|
|||
# Check for redirect to login page
|
||||
final_url = str(response.url)
|
||||
if "accounts.google.com" in final_url:
|
||||
raise ValueError(
|
||||
"Authentication expired. Run 'notebooklm login' to re-authenticate."
|
||||
)
|
||||
raise ValueError("Authentication expired. Run 'notebooklm login' to re-authenticate.")
|
||||
|
||||
# Extract SNlM0e (CSRF token) - REQUIRED
|
||||
csrf_match = re.search(r'"SNlM0e":"([^"]+)"', response.text)
|
||||
|
|
|
|||
|
|
@ -32,17 +32,17 @@ from .auth import DEFAULT_STORAGE_PATH
|
|||
|
||||
# Import command groups from cli package
|
||||
from .cli import (
|
||||
source,
|
||||
artifact,
|
||||
generate,
|
||||
download,
|
||||
generate,
|
||||
note,
|
||||
skill,
|
||||
research,
|
||||
register_chat_commands,
|
||||
register_notebook_commands,
|
||||
# Register functions for top-level commands
|
||||
register_session_commands,
|
||||
register_notebook_commands,
|
||||
register_chat_commands,
|
||||
research,
|
||||
skill,
|
||||
source,
|
||||
)
|
||||
from .cli.grouped import SectionedGroup
|
||||
|
||||
|
|
|
|||
|
|
@ -1,36 +1,36 @@
|
|||
"""RPC protocol implementation for NotebookLM batchexecute API."""
|
||||
|
||||
from .decoder import (
|
||||
RPCError,
|
||||
collect_rpc_ids,
|
||||
decode_response,
|
||||
extract_rpc_result,
|
||||
parse_chunked_response,
|
||||
strip_anti_xssi,
|
||||
)
|
||||
from .encoder import build_request_body, encode_rpc_request
|
||||
from .types import (
|
||||
RPCMethod,
|
||||
BATCHEXECUTE_URL,
|
||||
QUERY_URL,
|
||||
UPLOAD_URL,
|
||||
StudioContentType,
|
||||
ArtifactStatus,
|
||||
AudioFormat,
|
||||
AudioLength,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
QuizQuantity,
|
||||
QuizDifficulty,
|
||||
InfographicOrientation,
|
||||
InfographicDetail,
|
||||
SlideDeckFormat,
|
||||
SlideDeckLength,
|
||||
ReportFormat,
|
||||
ChatGoal,
|
||||
ChatResponseLength,
|
||||
DriveMimeType,
|
||||
ExportType,
|
||||
)
|
||||
from .encoder import encode_rpc_request, build_request_body
|
||||
from .decoder import (
|
||||
strip_anti_xssi,
|
||||
parse_chunked_response,
|
||||
extract_rpc_result,
|
||||
collect_rpc_ids,
|
||||
decode_response,
|
||||
RPCError,
|
||||
InfographicDetail,
|
||||
InfographicOrientation,
|
||||
QuizDifficulty,
|
||||
QuizQuantity,
|
||||
ReportFormat,
|
||||
RPCMethod,
|
||||
SlideDeckFormat,
|
||||
SlideDeckLength,
|
||||
StudioContentType,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -14,9 +14,9 @@ class RPCError(Exception):
|
|||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
rpc_id: Optional[str] = None,
|
||||
code: Optional[Any] = None,
|
||||
found_ids: Optional[list[str]] = None,
|
||||
rpc_id: str | None = None,
|
||||
code: Any | None = None,
|
||||
found_ids: list[str] | None = None,
|
||||
):
|
||||
self.rpc_id = rpc_id
|
||||
self.code = code
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Encode RPC requests for NotebookLM batchexecute API."""
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from .types import RPCMethod
|
||||
|
|
@ -33,8 +33,8 @@ def encode_rpc_request(method: RPCMethod, params: list[Any]) -> list:
|
|||
|
||||
def build_request_body(
|
||||
rpc_request: list,
|
||||
csrf_token: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
csrf_token: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build form-encoded request body for batchexecute.
|
||||
|
|
@ -67,8 +67,8 @@ def build_request_body(
|
|||
def build_url_params(
|
||||
rpc_method: RPCMethod,
|
||||
source_path: str = "/",
|
||||
session_id: Optional[str] = None,
|
||||
bl: Optional[str] = None,
|
||||
session_id: str | None = None,
|
||||
bl: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Build URL query parameters for batchexecute request.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# NotebookLM API endpoints
|
||||
BATCHEXECUTE_URL = "https://notebooklm.google.com/_/LabsTailwindUi/data/batchexecute"
|
||||
QUERY_URL = "https://notebooklm.google.com/_/LabsTailwindUi/data/google.internal.labs.tailwind.orchestration.v1.LabsTailwindOrchestrationService/GenerateFreeFormStreamed"
|
||||
|
|
@ -93,7 +92,9 @@ class StudioContentType(int, Enum):
|
|||
"""
|
||||
|
||||
AUDIO = 1
|
||||
REPORT = 2 # Includes: Briefing Doc, Study Guide, Blog Post, White Paper, Research Proposal, etc.
|
||||
REPORT = (
|
||||
2 # Includes: Briefing Doc, Study Guide, Blog Post, White Paper, Research Proposal, etc.
|
||||
)
|
||||
VIDEO = 3
|
||||
QUIZ = 4 # Also used for flashcards
|
||||
QUIZ_FLASHCARD = 4 # Alias for backward compatibility
|
||||
|
|
|
|||
|
|
@ -14,23 +14,23 @@ from typing import Any, Optional
|
|||
|
||||
# Re-export enums from rpc/types.py for convenience
|
||||
from .rpc.types import (
|
||||
StudioContentType,
|
||||
AudioFormat,
|
||||
AudioLength,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
QuizQuantity,
|
||||
QuizDifficulty,
|
||||
InfographicOrientation,
|
||||
InfographicDetail,
|
||||
SlideDeckFormat,
|
||||
SlideDeckLength,
|
||||
ReportFormat,
|
||||
ChatGoal,
|
||||
ChatResponseLength,
|
||||
DriveMimeType,
|
||||
ExportType,
|
||||
InfographicDetail,
|
||||
InfographicOrientation,
|
||||
QuizDifficulty,
|
||||
QuizQuantity,
|
||||
ReportFormat,
|
||||
SlideDeckFormat,
|
||||
SlideDeckLength,
|
||||
SourceStatus,
|
||||
StudioContentType,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -99,7 +99,7 @@ class Notebook:
|
|||
|
||||
id: str
|
||||
title: str
|
||||
created_at: Optional[datetime] = None
|
||||
created_at: datetime | None = None
|
||||
sources_count: int = 0
|
||||
is_owner: bool = True
|
||||
|
||||
|
|
@ -197,7 +197,7 @@ class SourceTimeoutError(SourceError):
|
|||
last_status: The last observed status before timeout.
|
||||
"""
|
||||
|
||||
def __init__(self, source_id: str, timeout: float, last_status: Optional[int] = None):
|
||||
def __init__(self, source_id: str, timeout: float, last_status: int | None = None):
|
||||
self.source_id = source_id
|
||||
self.timeout = timeout
|
||||
self.last_status = last_status
|
||||
|
|
@ -231,10 +231,10 @@ class Source:
|
|||
"""
|
||||
|
||||
id: str
|
||||
title: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
title: str | None = None
|
||||
url: str | None = None
|
||||
source_type: str = "text"
|
||||
created_at: Optional[datetime] = None
|
||||
created_at: datetime | None = None
|
||||
status: int = SourceStatus.READY # Default to READY (2)
|
||||
|
||||
@property
|
||||
|
|
@ -253,9 +253,7 @@ class Source:
|
|||
return self.status == SourceStatus.ERROR
|
||||
|
||||
@classmethod
|
||||
def from_api_response(
|
||||
cls, data: list[Any], notebook_id: Optional[str] = None
|
||||
) -> "Source":
|
||||
def from_api_response(cls, data: list[Any], notebook_id: str | None = None) -> "Source":
|
||||
"""Parse source data from various API response formats.
|
||||
|
||||
The API returns different structures for different operations:
|
||||
|
|
@ -294,12 +292,7 @@ class Source:
|
|||
if len(entry[2]) > 7 and isinstance(entry[2][7], list):
|
||||
url = entry[2][7][0] if entry[2][7] else None
|
||||
|
||||
return cls(
|
||||
id=str(source_id),
|
||||
title=title,
|
||||
url=url,
|
||||
source_type="text"
|
||||
)
|
||||
return cls(id=str(source_id), title=title, url=url, source_type="text")
|
||||
|
||||
# Deeply nested: continue with URL extraction
|
||||
url = None
|
||||
|
|
@ -309,14 +302,14 @@ class Source:
|
|||
if isinstance(url_list, list) and len(url_list) > 0:
|
||||
url = url_list[0]
|
||||
if not url and len(entry[2]) > 0:
|
||||
if isinstance(entry[2][0], str) and entry[2][0].startswith('http'):
|
||||
if isinstance(entry[2][0], str) and entry[2][0].startswith("http"):
|
||||
url = entry[2][0]
|
||||
|
||||
# Determine source type
|
||||
source_type = "text"
|
||||
if url:
|
||||
source_type = "youtube" if "youtube.com" in url or "youtu.be" in url else "url"
|
||||
elif title and (title.endswith('.pdf') or title.endswith('.txt')):
|
||||
elif title and (title.endswith(".pdf") or title.endswith(".txt")):
|
||||
source_type = "text_file"
|
||||
|
||||
return cls(
|
||||
|
|
@ -350,9 +343,9 @@ class Artifact:
|
|||
title: str
|
||||
artifact_type: int # StudioContentType enum value
|
||||
status: int # 1=processing, 3=completed
|
||||
created_at: Optional[datetime] = None
|
||||
url: Optional[str] = None
|
||||
variant: Optional[int] = None # For type 4: 1=flashcards, 2=quiz
|
||||
created_at: datetime | None = None
|
||||
url: str | None = None
|
||||
variant: int | None = None # For type 4: 1=flashcards, 2=quiz
|
||||
|
||||
@classmethod
|
||||
def from_api_response(cls, data: list[Any]) -> "Artifact":
|
||||
|
|
@ -469,7 +462,7 @@ class Artifact:
|
|||
return self.artifact_type == 4 and self.variant == 1
|
||||
|
||||
@property
|
||||
def report_subtype(self) -> Optional[str]:
|
||||
def report_subtype(self) -> str | None:
|
||||
"""Get the report subtype for type 2 artifacts.
|
||||
|
||||
Returns:
|
||||
|
|
@ -493,10 +486,10 @@ class GenerationStatus:
|
|||
|
||||
task_id: str
|
||||
status: str # "pending", "in_progress", "completed", "failed"
|
||||
url: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
error_code: Optional[str] = None # e.g., "USER_DISPLAYABLE_ERROR" for rate limits
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
url: str | None = None
|
||||
error: str | None = None
|
||||
error_code: str | None = None # e.g., "USER_DISPLAYABLE_ERROR" for rate limits
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
|
|
@ -578,7 +571,7 @@ class Note:
|
|||
notebook_id: str
|
||||
title: str
|
||||
content: str
|
||||
created_at: Optional[datetime] = None
|
||||
created_at: datetime | None = None
|
||||
|
||||
@classmethod
|
||||
def from_api_response(cls, data: list[Any], notebook_id: str) -> "Note":
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Shared test fixtures."""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
|
||||
from notebooklm.rpc import RPCMethod
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ def build_rpc_response():
|
|||
data: The response data to encode.
|
||||
"""
|
||||
|
||||
def _build(rpc_id: Union[RPCMethod, str], data) -> str:
|
||||
def _build(rpc_id: RPCMethod | str, data) -> str:
|
||||
# Convert RPCMethod to string value if needed
|
||||
rpc_id_str = rpc_id.value if isinstance(rpc_id, RPCMethod) else rpc_id
|
||||
inner = json.dumps(data)
|
||||
|
|
|
|||
|
|
@ -2,27 +2,28 @@
|
|||
|
||||
import os
|
||||
import warnings
|
||||
import pytest
|
||||
import httpx
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Load .env file if python-dotenv is available
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
pass # python-dotenv not installed, rely on shell environment
|
||||
|
||||
from notebooklm import NotebookLMClient
|
||||
from notebooklm.auth import (
|
||||
load_auth_from_storage,
|
||||
AuthTokens,
|
||||
extract_csrf_from_html,
|
||||
extract_session_id_from_html,
|
||||
AuthTokens,
|
||||
load_auth_from_storage,
|
||||
)
|
||||
from notebooklm.paths import get_home_dir
|
||||
from notebooklm import NotebookLMClient
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Constants
|
||||
|
|
@ -187,8 +188,6 @@ async def cleanup_notebooks(created_notebooks, auth_tokens):
|
|||
warnings.warn(f"Failed to cleanup notebook {nb_id}: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Notebook Fixtures
|
||||
# =============================================================================
|
||||
|
|
@ -203,6 +202,7 @@ async def temp_notebook(client, created_notebooks, cleanup_notebooks):
|
|||
"""
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
|
||||
notebook = await client.notebooks.create(f"Test-{uuid4().hex[:8]}")
|
||||
created_notebooks.append(notebook.id)
|
||||
|
||||
|
|
@ -332,7 +332,7 @@ async def _cleanup_generation_notebook(client: NotebookLMClient, notebook_id: st
|
|||
notes = await client.notes.list(notebook_id)
|
||||
for note in notes:
|
||||
# Skip if no id or if it's a pinned system note
|
||||
if note.id and not getattr(note, 'pinned', False):
|
||||
if note.id and not getattr(note, "pinned", False):
|
||||
try:
|
||||
await client.notes.delete(notebook_id, note.id)
|
||||
except Exception:
|
||||
|
|
@ -432,5 +432,3 @@ async def generation_notebook_id(client):
|
|||
await client.notebooks.delete(notebook_id)
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to delete generation notebook {notebook_id}: {e}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,10 +8,13 @@ Generation tests are in test_generation.py. This file contains:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from .conftest import requires_auth, assert_generation_started
|
||||
|
||||
from notebooklm import Artifact, ReportSuggestion
|
||||
|
||||
from .conftest import assert_generation_started, requires_auth
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestArtifactRetrieval:
|
||||
|
|
|
|||
|
|
@ -1,31 +1,31 @@
|
|||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from .conftest import requires_auth
|
||||
from notebooklm import Artifact
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
# Magic bytes for file type verification
|
||||
PNG_MAGIC = b'\x89PNG\r\n\x1a\n'
|
||||
PDF_MAGIC = b'%PDF'
|
||||
MP4_FTYP = b'ftyp' # At offset 4
|
||||
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
|
||||
PDF_MAGIC = b"%PDF"
|
||||
MP4_FTYP = b"ftyp" # At offset 4
|
||||
|
||||
|
||||
def is_png(path: str) -> bool:
|
||||
"""Check if file is a valid PNG by magic bytes."""
|
||||
with open(path, 'rb') as f:
|
||||
with open(path, "rb") as f:
|
||||
return f.read(8) == PNG_MAGIC
|
||||
|
||||
|
||||
def is_pdf(path: str) -> bool:
|
||||
"""Check if file is a valid PDF by magic bytes."""
|
||||
with open(path, 'rb') as f:
|
||||
with open(path, "rb") as f:
|
||||
return f.read(4) == PDF_MAGIC
|
||||
|
||||
|
||||
def is_mp4(path: str) -> bool:
|
||||
"""Check if file is a valid MP4 by magic bytes."""
|
||||
with open(path, 'rb') as f:
|
||||
with open(path, "rb") as f:
|
||||
header = f.read(12)
|
||||
# MP4 has 'ftyp' at offset 4
|
||||
return len(header) >= 8 and header[4:8] == MP4_FTYP
|
||||
|
|
@ -106,7 +106,9 @@ class TestDownloadSlideDeck:
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_path = os.path.join(tmpdir, "slides.pdf")
|
||||
try:
|
||||
result = await client.artifacts.download_slide_deck(read_only_notebook_id, output_path)
|
||||
result = await client.artifacts.download_slide_deck(
|
||||
read_only_notebook_id, output_path
|
||||
)
|
||||
assert result == output_path
|
||||
assert os.path.exists(output_path)
|
||||
assert os.path.getsize(output_path) > 0
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,37 +12,35 @@ Notebook lifecycle:
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from .conftest import requires_auth, assert_generation_started
|
||||
|
||||
from notebooklm import (
|
||||
AudioFormat,
|
||||
AudioLength,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
QuizQuantity,
|
||||
QuizDifficulty,
|
||||
InfographicOrientation,
|
||||
InfographicDetail,
|
||||
InfographicOrientation,
|
||||
QuizDifficulty,
|
||||
QuizQuantity,
|
||||
SlideDeckFormat,
|
||||
SlideDeckLength,
|
||||
VideoFormat,
|
||||
VideoStyle,
|
||||
)
|
||||
|
||||
from .conftest import assert_generation_started, requires_auth
|
||||
|
||||
|
||||
@requires_auth
|
||||
class TestAudioGeneration:
|
||||
"""Audio generation tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_audio_default(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_audio_default(self, client, generation_notebook_id):
|
||||
"""Test audio generation with true defaults."""
|
||||
result = await client.artifacts.generate_audio(generation_notebook_id)
|
||||
assert_generation_started(result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_audio_brief(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_audio_brief(self, client, generation_notebook_id):
|
||||
"""Test audio generation with non-default format to verify param encoding."""
|
||||
result = await client.artifacts.generate_audio(
|
||||
generation_notebook_id,
|
||||
|
|
@ -52,9 +50,7 @@ class TestAudioGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_audio_deep_dive_long(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_audio_deep_dive_long(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_audio(
|
||||
generation_notebook_id,
|
||||
audio_format=AudioFormat.DEEP_DIVE,
|
||||
|
|
@ -64,9 +60,7 @@ class TestAudioGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_audio_brief_short(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_audio_brief_short(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_audio(
|
||||
generation_notebook_id,
|
||||
audio_format=AudioFormat.BRIEF,
|
||||
|
|
@ -76,9 +70,7 @@ class TestAudioGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_audio_critique(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_audio_critique(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_audio(
|
||||
generation_notebook_id,
|
||||
audio_format=AudioFormat.CRITIQUE,
|
||||
|
|
@ -87,9 +79,7 @@ class TestAudioGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_audio_debate(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_audio_debate(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_audio(
|
||||
generation_notebook_id,
|
||||
audio_format=AudioFormat.DEBATE,
|
||||
|
|
@ -98,9 +88,7 @@ class TestAudioGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_audio_with_language(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_audio_with_language(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_audio(
|
||||
generation_notebook_id,
|
||||
language="en",
|
||||
|
|
@ -113,9 +101,7 @@ class TestVideoGeneration:
|
|||
"""Video generation tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_video_default(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_video_default(self, client, generation_notebook_id):
|
||||
"""Test video generation with non-default style to verify param encoding."""
|
||||
result = await client.artifacts.generate_video(
|
||||
generation_notebook_id,
|
||||
|
|
@ -125,9 +111,7 @@ class TestVideoGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_video_explainer_anime(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_video_explainer_anime(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_video(
|
||||
generation_notebook_id,
|
||||
video_format=VideoFormat.EXPLAINER,
|
||||
|
|
@ -137,9 +121,7 @@ class TestVideoGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_video_brief_whiteboard(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_video_brief_whiteboard(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_video(
|
||||
generation_notebook_id,
|
||||
video_format=VideoFormat.BRIEF,
|
||||
|
|
@ -149,9 +131,7 @@ class TestVideoGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_video_with_instructions(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_video_with_instructions(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_video(
|
||||
generation_notebook_id,
|
||||
video_format=VideoFormat.EXPLAINER,
|
||||
|
|
@ -162,9 +142,7 @@ class TestVideoGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_video_kawaii_style(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_video_kawaii_style(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_video(
|
||||
generation_notebook_id,
|
||||
video_style=VideoStyle.KAWAII,
|
||||
|
|
@ -173,9 +151,7 @@ class TestVideoGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_video_watercolor_style(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_video_watercolor_style(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_video(
|
||||
generation_notebook_id,
|
||||
video_style=VideoStyle.WATERCOLOR,
|
||||
|
|
@ -184,9 +160,7 @@ class TestVideoGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_video_auto_style(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_video_auto_style(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_video(
|
||||
generation_notebook_id,
|
||||
video_style=VideoStyle.AUTO_SELECT,
|
||||
|
|
@ -199,9 +173,7 @@ class TestQuizGeneration:
|
|||
"""Quiz generation tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_quiz_default(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_quiz_default(self, client, generation_notebook_id):
|
||||
"""Test quiz generation with non-default difficulty to verify param encoding."""
|
||||
result = await client.artifacts.generate_quiz(
|
||||
generation_notebook_id,
|
||||
|
|
@ -211,9 +183,7 @@ class TestQuizGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_quiz_with_options(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_quiz_with_options(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_quiz(
|
||||
generation_notebook_id,
|
||||
quantity=QuizQuantity.MORE,
|
||||
|
|
@ -224,9 +194,7 @@ class TestQuizGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_quiz_fewer_easy(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_quiz_fewer_easy(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_quiz(
|
||||
generation_notebook_id,
|
||||
quantity=QuizQuantity.FEWER,
|
||||
|
|
@ -240,9 +208,7 @@ class TestFlashcardsGeneration:
|
|||
"""Flashcards generation tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_flashcards_default(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_flashcards_default(self, client, generation_notebook_id):
|
||||
"""Test flashcards generation with non-default quantity to verify param encoding."""
|
||||
result = await client.artifacts.generate_flashcards(
|
||||
generation_notebook_id,
|
||||
|
|
@ -252,9 +218,7 @@ class TestFlashcardsGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_flashcards_with_options(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_flashcards_with_options(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_flashcards(
|
||||
generation_notebook_id,
|
||||
quantity=QuizQuantity.STANDARD,
|
||||
|
|
@ -269,9 +233,7 @@ class TestInfographicGeneration:
|
|||
"""Infographic generation tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_infographic_default(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_infographic_default(self, client, generation_notebook_id):
|
||||
"""Test infographic generation with non-default orientation to verify param encoding."""
|
||||
result = await client.artifacts.generate_infographic(
|
||||
generation_notebook_id,
|
||||
|
|
@ -281,9 +243,7 @@ class TestInfographicGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_infographic_portrait_detailed(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_infographic_portrait_detailed(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_infographic(
|
||||
generation_notebook_id,
|
||||
orientation=InfographicOrientation.PORTRAIT,
|
||||
|
|
@ -294,9 +254,7 @@ class TestInfographicGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_infographic_square_concise(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_infographic_square_concise(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_infographic(
|
||||
generation_notebook_id,
|
||||
orientation=InfographicOrientation.SQUARE,
|
||||
|
|
@ -306,9 +264,7 @@ class TestInfographicGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_infographic_landscape(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_infographic_landscape(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_infographic(
|
||||
generation_notebook_id,
|
||||
orientation=InfographicOrientation.LANDSCAPE,
|
||||
|
|
@ -321,9 +277,7 @@ class TestSlideDeckGeneration:
|
|||
"""Slide deck generation tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_slide_deck_default(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_slide_deck_default(self, client, generation_notebook_id):
|
||||
"""Test slide deck generation with non-default format to verify param encoding."""
|
||||
result = await client.artifacts.generate_slide_deck(
|
||||
generation_notebook_id,
|
||||
|
|
@ -333,9 +287,7 @@ class TestSlideDeckGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_slide_deck_detailed(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_slide_deck_detailed(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_slide_deck(
|
||||
generation_notebook_id,
|
||||
slide_format=SlideDeckFormat.DETAILED_DECK,
|
||||
|
|
@ -346,9 +298,7 @@ class TestSlideDeckGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_slide_deck_presenter_short(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_slide_deck_presenter_short(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_slide_deck(
|
||||
generation_notebook_id,
|
||||
slide_format=SlideDeckFormat.PRESENTER_SLIDES,
|
||||
|
|
@ -362,9 +312,7 @@ class TestDataTableGeneration:
|
|||
"""Data table generation tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_data_table_default(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_data_table_default(self, client, generation_notebook_id):
|
||||
"""Test data table generation with instructions to verify param encoding."""
|
||||
result = await client.artifacts.generate_data_table(
|
||||
generation_notebook_id,
|
||||
|
|
@ -374,9 +322,7 @@ class TestDataTableGeneration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.variants
|
||||
async def test_generate_data_table_with_instructions(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_data_table_with_instructions(self, client, generation_notebook_id):
|
||||
result = await client.artifacts.generate_data_table(
|
||||
generation_notebook_id,
|
||||
instructions="Create a comparison table of key concepts",
|
||||
|
|
@ -408,9 +354,7 @@ class TestStudyGuideGeneration:
|
|||
"""Study guide generation tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_study_guide(
|
||||
self, client, generation_notebook_id
|
||||
):
|
||||
async def test_generate_study_guide(self, client, generation_notebook_id):
|
||||
"""Test study guide generation."""
|
||||
result = await client.artifacts.generate_study_guide(generation_notebook_id)
|
||||
assert_generation_started(result)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import pytest
|
||||
|
||||
from notebooklm import ChatGoal, ChatMode, Notebook, NotebookDescription
|
||||
|
||||
from .conftest import requires_auth
|
||||
from notebooklm import Notebook, NotebookDescription, ChatMode, ChatGoal
|
||||
|
||||
|
||||
@requires_auth
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""E2E tests for NotesAPI."""
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import requires_auth
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,10 @@ polling for results, and importing discovered sources.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from .conftest import requires_auth, POLL_INTERVAL, POLL_TIMEOUT
|
||||
|
||||
from .conftest import POLL_INTERVAL, POLL_TIMEOUT, requires_auth
|
||||
|
||||
|
||||
@requires_auth
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from notebooklm import Source, SourceStatus
|
||||
|
||||
from .conftest import requires_auth
|
||||
from notebooklm import Source, SourceStatus, SourceTimeoutError
|
||||
|
||||
|
||||
@requires_auth
|
||||
|
|
@ -27,9 +30,7 @@ class TestSourceOperations:
|
|||
@pytest.mark.asyncio
|
||||
async def test_add_url_source(self, client, temp_notebook):
|
||||
"""Test adding a URL source to an owned notebook."""
|
||||
source = await client.sources.add_url(
|
||||
temp_notebook.id, "https://httpbin.org/html"
|
||||
)
|
||||
source = await client.sources.add_url(temp_notebook.id, "https://httpbin.org/html")
|
||||
assert isinstance(source, Source)
|
||||
assert source.id is not None
|
||||
# URL may or may not be returned in response
|
||||
|
|
@ -59,9 +60,7 @@ class TestSourceOperations:
|
|||
assert isinstance(source, Source)
|
||||
|
||||
# Rename
|
||||
renamed = await client.sources.rename(
|
||||
temp_notebook.id, source.id, "Renamed Test Source"
|
||||
)
|
||||
renamed = await client.sources.rename(temp_notebook.id, source.id, "Renamed Test Source")
|
||||
assert isinstance(renamed, Source)
|
||||
assert renamed.title == "Renamed Test Source"
|
||||
# No need to restore - temp_notebook is deleted after test
|
||||
|
|
@ -135,9 +134,7 @@ class TestSourceMutations:
|
|||
async def test_refresh_source(self, client, temp_notebook):
|
||||
"""Test refreshing a URL source."""
|
||||
# Add a URL source
|
||||
source = await client.sources.add_url(
|
||||
temp_notebook.id, "https://httpbin.org/html"
|
||||
)
|
||||
source = await client.sources.add_url(temp_notebook.id, "https://httpbin.org/html")
|
||||
assert source.id is not None
|
||||
|
||||
# Refresh it
|
||||
|
|
@ -151,9 +148,7 @@ class TestSourceMutations:
|
|||
async def test_check_freshness(self, client, temp_notebook):
|
||||
"""Test checking source freshness."""
|
||||
# Add a URL source
|
||||
source = await client.sources.add_url(
|
||||
temp_notebook.id, "https://httpbin.org/html"
|
||||
)
|
||||
source = await client.sources.add_url(temp_notebook.id, "https://httpbin.org/html")
|
||||
assert source.id is not None
|
||||
|
||||
await asyncio.sleep(2) # Wait for processing
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
"""Shared fixtures for integration tests."""
|
||||
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -34,7 +33,7 @@ def build_rpc_response():
|
|||
data: The response data to encode.
|
||||
"""
|
||||
|
||||
def _build(rpc_id: Union[RPCMethod, str], data) -> str:
|
||||
def _build(rpc_id: RPCMethod | str, data) -> str:
|
||||
# Convert RPCMethod to string value if needed
|
||||
rpc_id_str = rpc_id.value if isinstance(rpc_id, RPCMethod) else rpc_id
|
||||
inner = json.dumps(data)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import pytest
|
|||
from pytest_httpx import HTTPXMock
|
||||
|
||||
from notebooklm import NotebookLMClient
|
||||
from notebooklm.rpc import AudioFormat, AudioLength, VideoFormat, VideoStyle, RPCError, RPCMethod
|
||||
from notebooklm.rpc import AudioFormat, AudioLength, RPCError, RPCMethod, VideoFormat, VideoStyle
|
||||
|
||||
|
||||
class TestStudioContent:
|
||||
|
|
@ -529,7 +529,15 @@ class TestArtifactsAPI:
|
|||
RPCMethod.LIST_ARTIFACTS,
|
||||
[
|
||||
["art_001", "Quiz", 4, None, 3, None, [None, None, None, None, None, None, 2]],
|
||||
["art_002", "Flashcards", 4, None, 3, None, [None, None, None, None, None, None, 1]],
|
||||
[
|
||||
"art_002",
|
||||
"Flashcards",
|
||||
4,
|
||||
None,
|
||||
3,
|
||||
None,
|
||||
[None, None, None, None, None, None, 1],
|
||||
],
|
||||
],
|
||||
)
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
|
@ -569,7 +577,15 @@ class TestArtifactsAPI:
|
|||
RPCMethod.LIST_ARTIFACTS,
|
||||
[
|
||||
["art_001", "Quiz", 4, None, 3, None, [None, None, None, None, None, None, 2]],
|
||||
["art_002", "Flashcards", 4, None, 3, None, [None, None, None, None, None, None, 1]],
|
||||
[
|
||||
"art_002",
|
||||
"Flashcards",
|
||||
4,
|
||||
None,
|
||||
3,
|
||||
None,
|
||||
[None, None, None, None, None, None, 1],
|
||||
],
|
||||
],
|
||||
)
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@ import pytest
|
|||
from pytest_httpx import HTTPXMock
|
||||
|
||||
from notebooklm import NotebookLMClient
|
||||
from notebooklm.rpc import RPCMethod
|
||||
from notebooklm.rpc import ChatGoal, ChatResponseLength
|
||||
from notebooklm.rpc import ChatGoal, ChatResponseLength, RPCMethod
|
||||
from notebooklm.types import ChatMode
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ to avoid asyncio event loop conflicts with pytest-asyncio.
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from notebooklm.cli.download_helpers import select_artifact, artifact_title_to_filename
|
||||
|
||||
from notebooklm.cli.download_helpers import artifact_title_to_filename, select_artifact
|
||||
|
||||
|
||||
class TestArtifactSelection:
|
||||
|
|
@ -19,7 +19,7 @@ class TestArtifactSelection:
|
|||
{"id": "a2", "title": "Meeting Notes", "created_at": 2000},
|
||||
{"id": "a3", "title": "Debate Round 3", "created_at": 3000}, # Latest "debate"
|
||||
{"id": "a4", "title": "Debate Round 2", "created_at": 2500},
|
||||
{"id": "a5", "title": "Overview", "created_at": 4000}, # Latest overall
|
||||
{"id": "a5", "title": "Overview", "created_at": 4000}, # Latest overall
|
||||
]
|
||||
|
||||
selected, reason = select_artifact(artifacts, latest=True, name="debate")
|
||||
|
|
@ -32,9 +32,9 @@ class TestArtifactSelection:
|
|||
def test_filter_then_select_earliest(self):
|
||||
"""Should apply name filter BEFORE selecting earliest."""
|
||||
artifacts = [
|
||||
{"id": "a1", "title": "Introduction", "created_at": 1000}, # Earliest overall
|
||||
{"id": "a1", "title": "Introduction", "created_at": 1000}, # Earliest overall
|
||||
{"id": "a2", "title": "Chapter 2", "created_at": 3000},
|
||||
{"id": "a3", "title": "Chapter 1", "created_at": 2000}, # Earliest "chapter"
|
||||
{"id": "a3", "title": "Chapter 1", "created_at": 2000}, # Earliest "chapter"
|
||||
{"id": "a4", "title": "Chapter 3", "created_at": 4000},
|
||||
{"id": "a5", "title": "Conclusion", "created_at": 5000},
|
||||
]
|
||||
|
|
@ -175,7 +175,7 @@ class TestFilenameGeneration:
|
|||
assert "/" not in filename
|
||||
assert ":" not in filename
|
||||
assert '"' not in filename
|
||||
assert filename == 'Audio_ Part 1 _ _Main_.mp3'
|
||||
assert filename == "Audio_ Part 1 _ _Main_.mp3"
|
||||
|
||||
def test_handle_duplicates(self):
|
||||
"""Should add (2), (3) suffixes for duplicates."""
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import pytest
|
||||
from pytest_httpx import HTTPXMock
|
||||
|
||||
from notebooklm import NotebookLMClient, Notebook
|
||||
from notebooklm import Notebook, NotebookLMClient
|
||||
from notebooklm.rpc import RPCMethod
|
||||
|
||||
|
||||
|
|
@ -199,9 +199,7 @@ class TestSummary:
|
|||
httpx_mock: HTTPXMock,
|
||||
build_rpc_response,
|
||||
):
|
||||
response = build_rpc_response(
|
||||
RPCMethod.SUMMARIZE, ["Summary of the notebook content..."]
|
||||
)
|
||||
response = build_rpc_response(RPCMethod.SUMMARIZE, ["Summary of the notebook content..."])
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
||||
async with NotebookLMClient(auth_tokens) as client:
|
||||
|
|
@ -257,7 +255,16 @@ class TestRenameNotebook:
|
|||
# Get notebook response after rename
|
||||
get_response = build_rpc_response(
|
||||
RPCMethod.GET_NOTEBOOK,
|
||||
[["Renamed", [], "nb_123", "📘", None, [None, None, None, None, None, [1704067200, 0]]]],
|
||||
[
|
||||
[
|
||||
"Renamed",
|
||||
[],
|
||||
"nb_123",
|
||||
"📘",
|
||||
None,
|
||||
[None, None, None, None, None, [1704067200, 0]],
|
||||
]
|
||||
],
|
||||
)
|
||||
httpx_mock.add_response(content=get_response.encode())
|
||||
|
||||
|
|
@ -364,10 +371,12 @@ class TestNotebooksAPIAdditional:
|
|||
RPCMethod.SUMMARIZE,
|
||||
[
|
||||
["This notebook covers AI research."],
|
||||
[[
|
||||
["What are the main findings?", "Explain the key findings"],
|
||||
["How was the study conducted?", "Describe methodology"],
|
||||
]],
|
||||
[
|
||||
[
|
||||
["What are the main findings?", "Explain the key findings"],
|
||||
["How was the study conducted?", "Describe methodology"],
|
||||
]
|
||||
],
|
||||
],
|
||||
)
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
|
@ -464,11 +473,13 @@ class TestNotebookEdgeCases:
|
|||
RPCMethod.SUMMARIZE,
|
||||
[
|
||||
["Summary"],
|
||||
[[
|
||||
["Valid question", "Valid prompt"],
|
||||
["Only question"], # Missing prompt
|
||||
"not a list", # Not a list
|
||||
]],
|
||||
[
|
||||
[
|
||||
["Valid question", "Valid prompt"],
|
||||
["Only question"], # Missing prompt
|
||||
"not a list", # Not a list
|
||||
]
|
||||
],
|
||||
],
|
||||
)
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
|
|
|||
|
|
@ -68,7 +68,10 @@ class TestNotesAPI:
|
|||
[
|
||||
[
|
||||
["note_001", ["note_001", "Regular note content", None, None, "Regular Note"]],
|
||||
["mm_001", ["mm_001", '{"title":"Mind Map","children":[]}', None, None, "Mind Map"]],
|
||||
[
|
||||
"mm_001",
|
||||
["mm_001", '{"title":"Mind Map","children":[]}', None, None, "Mind Map"],
|
||||
],
|
||||
]
|
||||
],
|
||||
)
|
||||
|
|
@ -204,7 +207,10 @@ class TestNotesAPI:
|
|||
[
|
||||
[
|
||||
["note_001", ["note_001", "Regular note", None, None, "Note"]],
|
||||
["mm_001", ["mm_001", '{"title":"Mind Map 1","children":[]}', None, None, "MM1"]],
|
||||
[
|
||||
"mm_001",
|
||||
["mm_001", '{"title":"Mind Map 1","children":[]}', None, None, "MM1"],
|
||||
],
|
||||
["mm_002", ["mm_002", '{"nodes":[{"id":"1"}]}', None, None, "MM2"]],
|
||||
]
|
||||
],
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ class TestResearchAPI:
|
|||
build_rpc_response,
|
||||
):
|
||||
"""Test starting fast web research."""
|
||||
response = build_rpc_response(
|
||||
"Ljjv0c", ["task_123", "report_456"]
|
||||
)
|
||||
response = build_rpc_response("Ljjv0c", ["task_123", "report_456"])
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
||||
async with NotebookLMClient(auth_tokens) as client:
|
||||
|
|
@ -43,9 +41,7 @@ class TestResearchAPI:
|
|||
build_rpc_response,
|
||||
):
|
||||
"""Test starting fast drive research."""
|
||||
response = build_rpc_response(
|
||||
"Ljjv0c", ["task_789", None]
|
||||
)
|
||||
response = build_rpc_response("Ljjv0c", ["task_789", None])
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
||||
async with NotebookLMClient(auth_tokens) as client:
|
||||
|
|
@ -65,15 +61,11 @@ class TestResearchAPI:
|
|||
build_rpc_response,
|
||||
):
|
||||
"""Test starting deep web research."""
|
||||
response = build_rpc_response(
|
||||
"QA9ei", ["task_deep", "report_deep"]
|
||||
)
|
||||
response = build_rpc_response("QA9ei", ["task_deep", "report_deep"])
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
||||
async with NotebookLMClient(auth_tokens) as client:
|
||||
result = await client.research.start(
|
||||
"nb_123", "AI ethics", source="web", mode="deep"
|
||||
)
|
||||
result = await client.research.start("nb_123", "AI ethics", source="web", mode="deep")
|
||||
|
||||
assert result is not None
|
||||
assert result["mode"] == "deep"
|
||||
|
|
@ -90,9 +82,7 @@ class TestResearchAPI:
|
|||
"""Test that deep research on drive raises ValueError."""
|
||||
async with NotebookLMClient(auth_tokens) as client:
|
||||
with pytest.raises(ValueError, match="Deep Research only supports Web"):
|
||||
await client.research.start(
|
||||
"nb_123", "query", source="drive", mode="deep"
|
||||
)
|
||||
await client.research.start("nb_123", "query", source="drive", mode="deep")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_invalid_source_raises(
|
||||
|
|
@ -228,9 +218,7 @@ class TestResearchAPI:
|
|||
{"url": "https://example.com/quantum", "title": "Quantum Computing Guide"},
|
||||
{"url": "https://example.com/ai", "title": "AI Research Paper"},
|
||||
]
|
||||
result = await client.research.import_sources(
|
||||
"nb_123", "task_123", sources_to_import
|
||||
)
|
||||
result = await client.research.import_sources("nb_123", "task_123", sources_to_import)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "src_001"
|
||||
|
|
|
|||
|
|
@ -53,9 +53,7 @@ class TestAddSource:
|
|||
httpx_mock.add_response(content=response.encode())
|
||||
|
||||
async with NotebookLMClient(auth_tokens) as client:
|
||||
source = await client.sources.add_text(
|
||||
"nb_123", "My Document", "This is the content"
|
||||
)
|
||||
source = await client.sources.add_text("nb_123", "My Document", "This is the content")
|
||||
|
||||
assert isinstance(source, Source)
|
||||
assert source.id == "source_id"
|
||||
|
|
@ -148,9 +146,37 @@ class TestSourcesAPI:
|
|||
[
|
||||
"Test Notebook",
|
||||
[
|
||||
[["src_001"], "My Article", [None, 11, [1704067200, 0], None, 5, None, None, ["https://example.com"]], [None, 2]],
|
||||
[
|
||||
["src_001"],
|
||||
"My Article",
|
||||
[
|
||||
None,
|
||||
11,
|
||||
[1704067200, 0],
|
||||
None,
|
||||
5,
|
||||
None,
|
||||
None,
|
||||
["https://example.com"],
|
||||
],
|
||||
[None, 2],
|
||||
],
|
||||
[["src_002"], "My Text", [None, 0, [1704153600, 0]], [None, 2]],
|
||||
[["src_003"], "YouTube Video", [None, 11, [1704240000, 0], None, 5, None, None, ["https://youtube.com/watch?v=abc"]], [None, 2]],
|
||||
[
|
||||
["src_003"],
|
||||
"YouTube Video",
|
||||
[
|
||||
None,
|
||||
11,
|
||||
[1704240000, 0],
|
||||
None,
|
||||
5,
|
||||
None,
|
||||
None,
|
||||
["https://youtube.com/watch?v=abc"],
|
||||
],
|
||||
[None, 2],
|
||||
],
|
||||
],
|
||||
"nb_123",
|
||||
"📘",
|
||||
|
|
@ -180,7 +206,16 @@ class TestSourcesAPI:
|
|||
"""Test listing sources from empty notebook."""
|
||||
response = build_rpc_response(
|
||||
RPCMethod.GET_NOTEBOOK,
|
||||
[["Empty Notebook", [], "nb_123", "📘", None, [None, None, None, None, None, [1704067200, 0]]]],
|
||||
[
|
||||
[
|
||||
"Empty Notebook",
|
||||
[],
|
||||
"nb_123",
|
||||
"📘",
|
||||
None,
|
||||
[None, None, None, None, None, [1704067200, 0]],
|
||||
]
|
||||
],
|
||||
)
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
||||
|
|
@ -199,7 +234,16 @@ class TestSourcesAPI:
|
|||
"""Test getting a non-existent source."""
|
||||
response = build_rpc_response(
|
||||
RPCMethod.GET_NOTEBOOK,
|
||||
[["Notebook", [[["src_001"], "Source 1", [None, 0], [None, 2]]], "nb_123", "📘", None, [None, None, None, None, None, [1704067200, 0]]]],
|
||||
[
|
||||
[
|
||||
"Notebook",
|
||||
[[["src_001"], "Source 1", [None, 0], [None, 2]]],
|
||||
"nb_123",
|
||||
"📘",
|
||||
None,
|
||||
[None, None, None, None, None, [1704067200, 0]],
|
||||
]
|
||||
],
|
||||
)
|
||||
httpx_mock.add_response(content=response.encode())
|
||||
|
||||
|
|
@ -367,11 +411,7 @@ class TestAddFileSource:
|
|||
# Step 1: Mock RPC registration response (o4cbdc)
|
||||
rpc_response = build_rpc_response(
|
||||
RPCMethod.ADD_SOURCE_FILE,
|
||||
[
|
||||
[
|
||||
[["file_source_123"], "test_document.txt", [None, None, None, None, 0]]
|
||||
]
|
||||
],
|
||||
[[[["file_source_123"], "test_document.txt", [None, None, None, None, 0]]]],
|
||||
)
|
||||
httpx_mock.add_response(
|
||||
url=re.compile(r".*batchexecute.*"),
|
||||
|
|
@ -504,6 +544,7 @@ class TestAddFileSource:
|
|||
|
||||
# Verify body contains metadata
|
||||
import json
|
||||
|
||||
body = json.loads(start_request.content.decode())
|
||||
assert body["PROJECT_ID"] == "nb_123"
|
||||
assert body["SOURCE_NAME"] == "document.txt"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""Shared fixtures for CLI unit tests."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -72,6 +73,24 @@ def create_mock_client():
|
|||
return mock_client
|
||||
|
||||
|
||||
def get_cli_module(module_path: str):
|
||||
"""Get the actual CLI module by path, bypassing shadowed names.
|
||||
|
||||
In cli/__init__.py, module names are shadowed by click groups with the same name
|
||||
(e.g., `from .source import source`). This function uses importlib to get the
|
||||
actual module for Python 3.10 compatibility.
|
||||
|
||||
Args:
|
||||
module_path: The module name within notebooklm.cli (e.g., "source", "skill")
|
||||
|
||||
Returns:
|
||||
The actual module object
|
||||
"""
|
||||
import importlib
|
||||
|
||||
return importlib.import_module(f"notebooklm.cli.{module_path}")
|
||||
|
||||
|
||||
def patch_client_for_module(module_path: str):
|
||||
"""Create a context manager that patches NotebookLMClient in the given module.
|
||||
|
||||
|
|
@ -86,8 +105,16 @@ def patch_client_for_module(module_path: str):
|
|||
mock_client = create_mock_client()
|
||||
mock_cls.return_value = mock_client
|
||||
# ... run test
|
||||
|
||||
Note:
|
||||
Uses importlib to get the actual module, not the click group that shadows
|
||||
the module name in cli/__init__.py. This is required for Python 3.10
|
||||
compatibility where mock.patch's string path resolution gets the wrong object.
|
||||
"""
|
||||
return patch(f"notebooklm.cli.{module_path}.NotebookLMClient")
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(f"notebooklm.cli.{module_path}")
|
||||
return patch.object(module, "NotebookLMClient")
|
||||
|
||||
|
||||
class MultiMockProxy:
|
||||
|
|
@ -96,15 +123,16 @@ class MultiMockProxy:
|
|||
When you set return_value on this proxy, it propagates to all mocks.
|
||||
Other attribute access is delegated to the primary mock.
|
||||
"""
|
||||
|
||||
def __init__(self, mocks):
|
||||
object.__setattr__(self, '_mocks', mocks)
|
||||
object.__setattr__(self, '_primary', mocks[0])
|
||||
object.__setattr__(self, "_mocks", mocks)
|
||||
object.__setattr__(self, "_primary", mocks[0])
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._primary, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name == 'return_value':
|
||||
if name == "return_value":
|
||||
# Propagate return_value to all mocks
|
||||
for m in self._mocks:
|
||||
m.return_value = value
|
||||
|
|
@ -118,6 +146,7 @@ class MultiPatcher:
|
|||
After refactoring, commands are spread across multiple modules, so we need
|
||||
to patch NotebookLMClient in all of them.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.patches = [
|
||||
patch("notebooklm.cli.notebook.NotebookLMClient"),
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
"""Tests for artifact CLI commands."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from notebooklm.notebooklm_cli import cli
|
||||
from notebooklm.types import Artifact, ReportSuggestion
|
||||
from notebooklm.types import Artifact
|
||||
|
||||
from .conftest import create_mock_client, patch_client_for_module
|
||||
|
||||
|
|
@ -84,9 +84,7 @@ class TestArtifactList:
|
|||
]
|
||||
)
|
||||
mock_client.notes.list_mind_maps = AsyncMock(return_value=[])
|
||||
mock_client.notebooks.get = AsyncMock(
|
||||
return_value=MagicMock(title="Test Notebook")
|
||||
)
|
||||
mock_client.notebooks.get = AsyncMock(return_value=MagicMock(title="Test Notebook"))
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
|
|
@ -120,7 +118,7 @@ class TestArtifactGet:
|
|||
title="Test Artifact",
|
||||
artifact_type=4,
|
||||
status=3,
|
||||
created_at=datetime(2024, 1, 1)
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
@ -161,9 +159,7 @@ class TestArtifactRename:
|
|||
mock_client = create_mock_client()
|
||||
# Mock list for partial ID resolution
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="art_123", title="Old Title", artifact_type=4, status=3)
|
||||
]
|
||||
return_value=[Artifact(id="art_123", title="Old Title", artifact_type=4, status=3)]
|
||||
)
|
||||
mock_client.notes.list_mind_maps = AsyncMock(return_value=[])
|
||||
mock_client.artifacts.rename = AsyncMock(
|
||||
|
|
@ -185,9 +181,7 @@ class TestArtifactRename:
|
|||
mock_client = create_mock_client()
|
||||
# Mock list for partial ID resolution (include the mind map)
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="mm_123", title="Old Title", artifact_type=5, status=3)
|
||||
]
|
||||
return_value=[Artifact(id="mm_123", title="Old Title", artifact_type=5, status=3)]
|
||||
)
|
||||
mock_client.notes.list_mind_maps = AsyncMock(
|
||||
return_value=[
|
||||
|
|
@ -227,9 +221,7 @@ class TestArtifactDelete:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["artifact", "delete", "art_123", "-n", "nb_123", "-y"]
|
||||
)
|
||||
result = runner.invoke(cli, ["artifact", "delete", "art_123", "-n", "nb_123", "-y"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Deleted artifact" in result.output
|
||||
|
|
@ -253,9 +245,7 @@ class TestArtifactDelete:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["artifact", "delete", "mm_456", "-n", "nb_123", "-y"]
|
||||
)
|
||||
result = runner.invoke(cli, ["artifact", "delete", "mm_456", "-n", "nb_123", "-y"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Cleared mind map" in result.output
|
||||
|
|
@ -273,9 +263,7 @@ class TestArtifactExport:
|
|||
mock_client = create_mock_client()
|
||||
# Mock list for partial ID resolution
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="art_123", title="Doc", artifact_type=2, status=3)
|
||||
]
|
||||
return_value=[Artifact(id="art_123", title="Doc", artifact_type=2, status=3)]
|
||||
)
|
||||
mock_client.artifacts.export = AsyncMock(
|
||||
return_value={"url": "https://docs.google.com/document/d/123"}
|
||||
|
|
@ -294,6 +282,7 @@ class TestArtifactExport:
|
|||
mock_client.artifacts.export.assert_called_once()
|
||||
call_args = mock_client.artifacts.export.call_args
|
||||
from notebooklm.rpc import ExportType
|
||||
|
||||
# call_args[0] = (notebook_id, artifact_id, content, title, export_type)
|
||||
assert call_args[0][2] is None, "content should be None (backend retrieves it)"
|
||||
assert call_args[0][4] == ExportType.DOCS, "export_type should be ExportType.DOCS"
|
||||
|
|
@ -303,9 +292,7 @@ class TestArtifactExport:
|
|||
mock_client = create_mock_client()
|
||||
# Mock list for partial ID resolution
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="art_123", title="Table", artifact_type=9, status=3)
|
||||
]
|
||||
return_value=[Artifact(id="art_123", title="Table", artifact_type=9, status=3)]
|
||||
)
|
||||
mock_client.artifacts.export = AsyncMock(
|
||||
return_value={"url": "https://sheets.google.com/spreadsheets/d/123"}
|
||||
|
|
@ -315,7 +302,18 @@ class TestArtifactExport:
|
|||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["artifact", "export", "art_123", "--title", "My Sheet", "--type", "sheets", "-n", "nb_123"]
|
||||
cli,
|
||||
[
|
||||
"artifact",
|
||||
"export",
|
||||
"art_123",
|
||||
"--title",
|
||||
"My Sheet",
|
||||
"--type",
|
||||
"sheets",
|
||||
"-n",
|
||||
"nb_123",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -324,6 +322,7 @@ class TestArtifactExport:
|
|||
mock_client.artifacts.export.assert_called_once()
|
||||
call_args = mock_client.artifacts.export.call_args
|
||||
from notebooklm.rpc import ExportType
|
||||
|
||||
# call_args[0] = (notebook_id, artifact_id, content, title, export_type)
|
||||
assert call_args[0][2] is None, "content should be None (backend retrieves it)"
|
||||
assert call_args[0][4] == ExportType.SHEETS, "export_type should be ExportType.SHEETS"
|
||||
|
|
@ -333,9 +332,7 @@ class TestArtifactExport:
|
|||
mock_client = create_mock_client()
|
||||
# Mock list for partial ID resolution
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="art_123", title="Doc", artifact_type=2, status=3)
|
||||
]
|
||||
return_value=[Artifact(id="art_123", title="Doc", artifact_type=2, status=3)]
|
||||
)
|
||||
mock_client.artifacts.export = AsyncMock(return_value=None)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
@ -384,24 +381,18 @@ class TestArtifactWait:
|
|||
mock_client = create_mock_client()
|
||||
# Mock list for partial ID resolution
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="art_123", title="Test", artifact_type=1, status=3)
|
||||
]
|
||||
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=3)]
|
||||
)
|
||||
mock_client.artifacts.wait_for_completion = AsyncMock(
|
||||
return_value=MagicMock(
|
||||
status="completed",
|
||||
url="https://example.com/audio.mp3",
|
||||
error=None
|
||||
status="completed", url="https://example.com/audio.mp3", error=None
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["artifact", "wait", "art_123", "-n", "nb_123"]
|
||||
)
|
||||
result = runner.invoke(cli, ["artifact", "wait", "art_123", "-n", "nb_123"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Artifact completed" in result.output
|
||||
|
|
@ -411,24 +402,18 @@ class TestArtifactWait:
|
|||
with patch_client_for_module("artifact") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="art_123", title="Test", artifact_type=1, status=1)
|
||||
]
|
||||
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=1)]
|
||||
)
|
||||
mock_client.artifacts.wait_for_completion = AsyncMock(
|
||||
return_value=MagicMock(
|
||||
status="failed",
|
||||
url=None,
|
||||
error="Generation failed due to content policy"
|
||||
status="failed", url=None, error="Generation failed due to content policy"
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["artifact", "wait", "art_123", "-n", "nb_123"]
|
||||
)
|
||||
result = runner.invoke(cli, ["artifact", "wait", "art_123", "-n", "nb_123"])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Generation failed" in result.output
|
||||
|
|
@ -438,9 +423,7 @@ class TestArtifactWait:
|
|||
with patch_client_for_module("artifact") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="art_123", title="Test", artifact_type=1, status=1)
|
||||
]
|
||||
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=1)]
|
||||
)
|
||||
mock_client.artifacts.wait_for_completion = AsyncMock(
|
||||
side_effect=TimeoutError("Timed out")
|
||||
|
|
@ -461,15 +444,11 @@ class TestArtifactWait:
|
|||
with patch_client_for_module("artifact") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="art_123", title="Test", artifact_type=1, status=3)
|
||||
]
|
||||
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=3)]
|
||||
)
|
||||
mock_client.artifacts.wait_for_completion = AsyncMock(
|
||||
return_value=MagicMock(
|
||||
status="completed",
|
||||
url="https://example.com/audio.mp3",
|
||||
error=None
|
||||
status="completed", url="https://example.com/audio.mp3", error=None
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
@ -490,9 +469,7 @@ class TestArtifactWait:
|
|||
with patch_client_for_module("artifact") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
Artifact(id="art_123", title="Test", artifact_type=1, status=1)
|
||||
]
|
||||
return_value=[Artifact(id="art_123", title="Test", artifact_type=1, status=1)]
|
||||
)
|
||||
mock_client.artifacts.wait_for_completion = AsyncMock(
|
||||
side_effect=TimeoutError("Timed out")
|
||||
|
|
|
|||
|
|
@ -1,19 +1,24 @@
|
|||
"""Tests for download CLI commands."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from notebooklm.notebooklm_cli import cli
|
||||
from notebooklm.types import Artifact
|
||||
|
||||
from .conftest import create_mock_client, patch_client_for_module
|
||||
from .conftest import create_mock_client, get_cli_module, patch_client_for_module
|
||||
|
||||
# Get the actual download module (not the click group that shadows it)
|
||||
download_module = get_cli_module("download")
|
||||
|
||||
|
||||
def make_artifact(id: str, title: str, artifact_type: int, status: int = 3, created_at: datetime = None) -> Artifact:
|
||||
def make_artifact(
|
||||
id: str, title: str, artifact_type: int, status: int = 3, created_at: datetime = None
|
||||
) -> Artifact:
|
||||
"""Create an Artifact for testing."""
|
||||
return Artifact(
|
||||
id=id,
|
||||
|
|
@ -65,8 +70,10 @@ class TestDownloadAudio:
|
|||
mock_client.artifacts.download_audio = mock_download_audio
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
|
|
@ -84,13 +91,13 @@ class TestDownloadAudio:
|
|||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["download", "audio", "--dry-run", "-n", "nb_123"]
|
||||
)
|
||||
result = runner.invoke(cli, ["download", "audio", "--dry-run", "-n", "nb_123"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "DRY RUN" in result.output
|
||||
|
|
@ -101,8 +108,10 @@ class TestDownloadAudio:
|
|||
mock_client.artifacts.list = AsyncMock(return_value=[])
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["download", "audio", "-n", "nb_123"])
|
||||
|
|
@ -133,8 +142,10 @@ class TestDownloadVideo:
|
|||
mock_client.artifacts.download_video = mock_download_video
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
|
|
@ -168,8 +179,10 @@ class TestDownloadInfographic:
|
|||
mock_client.artifacts.download_infographic = mock_download_infographic
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
|
|
@ -204,8 +217,10 @@ class TestDownloadSlideDeck:
|
|||
mock_client.artifacts.download_slide_deck = mock_download_slide_deck
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test", "HSID": "test", "SSID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
|
|
@ -235,15 +250,21 @@ class TestDownloadFlags:
|
|||
# Set up artifacts namespace (pre-created by create_mock_client)
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
make_artifact("audio_old", "Old Audio", 1, created_at=datetime.fromtimestamp(1000000000)),
|
||||
make_artifact("audio_new", "New Audio", 1, created_at=datetime.fromtimestamp(2000000000)),
|
||||
make_artifact(
|
||||
"audio_old", "Old Audio", 1, created_at=datetime.fromtimestamp(1000000000)
|
||||
),
|
||||
make_artifact(
|
||||
"audio_new", "New Audio", 1, created_at=datetime.fromtimestamp(2000000000)
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.artifacts.download_audio = mock_download_audio
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
|
|
@ -266,15 +287,21 @@ class TestDownloadFlags:
|
|||
# Set up artifacts namespace (pre-created by create_mock_client)
|
||||
mock_client.artifacts.list = AsyncMock(
|
||||
return_value=[
|
||||
make_artifact("audio_old", "Old Audio", 1, created_at=datetime.fromtimestamp(1000000000)),
|
||||
make_artifact("audio_new", "New Audio", 1, created_at=datetime.fromtimestamp(2000000000)),
|
||||
make_artifact(
|
||||
"audio_old", "Old Audio", 1, created_at=datetime.fromtimestamp(1000000000)
|
||||
),
|
||||
make_artifact(
|
||||
"audio_new", "New Audio", 1, created_at=datetime.fromtimestamp(2000000000)
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.artifacts.download_audio = mock_download_audio
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
|
|
@ -302,8 +329,10 @@ class TestDownloadFlags:
|
|||
mock_client.artifacts.download_audio = mock_download_audio
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
|
|
@ -326,8 +355,10 @@ class TestDownloadFlags:
|
|||
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.download.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch("notebooklm.cli.download.load_auth_from_storage") as mock_load:
|
||||
with patch.object(
|
||||
download_module, "fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch.object(download_module, "load_auth_from_storage") as mock_load:
|
||||
mock_load.return_value = {"SID": "test"}
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"""Tests for generate CLI commands."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from notebooklm.notebooklm_cli import cli
|
||||
|
|
@ -60,7 +60,9 @@ class TestGenerateAudio:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["generate", "audio", "--format", "debate", "-n", "nb_123"])
|
||||
result = runner.invoke(
|
||||
cli, ["generate", "audio", "--format", "debate", "-n", "nb_123"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_client.artifacts.generate_audio.assert_called()
|
||||
|
|
@ -75,7 +77,9 @@ class TestGenerateAudio:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["generate", "audio", "--length", "long", "-n", "nb_123"])
|
||||
result = runner.invoke(
|
||||
cli, ["generate", "audio", "--length", "long", "-n", "nb_123"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
|
@ -160,7 +164,9 @@ class TestGenerateVideo:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["generate", "video", "--style", "kawaii", "-n", "nb_123"])
|
||||
result = runner.invoke(
|
||||
cli, ["generate", "video", "--style", "kawaii", "-n", "nb_123"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
|
@ -196,7 +202,17 @@ class TestGenerateQuiz:
|
|||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["generate", "quiz", "--quantity", "more", "--difficulty", "hard", "-n", "nb_123"]
|
||||
cli,
|
||||
[
|
||||
"generate",
|
||||
"quiz",
|
||||
"--quantity",
|
||||
"more",
|
||||
"--difficulty",
|
||||
"hard",
|
||||
"-n",
|
||||
"nb_123",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -254,7 +270,17 @@ class TestGenerateSlideDeck:
|
|||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["generate", "slide-deck", "--format", "presenter", "--length", "short", "-n", "nb_123"]
|
||||
cli,
|
||||
[
|
||||
"generate",
|
||||
"slide-deck",
|
||||
"--format",
|
||||
"presenter",
|
||||
"--length",
|
||||
"short",
|
||||
"-n",
|
||||
"nb_123",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -291,7 +317,17 @@ class TestGenerateInfographic:
|
|||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["generate", "infographic", "--orientation", "portrait", "--detail", "detailed", "-n", "nb_123"]
|
||||
cli,
|
||||
[
|
||||
"generate",
|
||||
"infographic",
|
||||
"--orientation",
|
||||
"portrait",
|
||||
"--detail",
|
||||
"detailed",
|
||||
"-n",
|
||||
"nb_123",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
|
|||
|
|
@ -88,9 +88,15 @@ class TestSectionedHelp:
|
|||
# (it may still appear if Click adds it, but our sections should dominate)
|
||||
lines = result.output.split("\n")
|
||||
# Count section headers
|
||||
section_count = sum(1 for line in lines if line.strip().endswith(":") and
|
||||
any(s in line for s in ["Session", "Notebooks", "Chat",
|
||||
"Command Groups", "Artifact Actions"]))
|
||||
section_count = sum(
|
||||
1
|
||||
for line in lines
|
||||
if line.strip().endswith(":")
|
||||
and any(
|
||||
s in line
|
||||
for s in ["Session", "Notebooks", "Chat", "Command Groups", "Artifact Actions"]
|
||||
)
|
||||
)
|
||||
assert section_count >= 4 # At least 4 of our sections should appear (no Insights anymore)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,39 +1,37 @@
|
|||
"""Tests for CLI helper functions."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from io import StringIO
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
from notebooklm.cli.helpers import (
|
||||
ARTIFACT_TYPE_MAP,
|
||||
clear_context,
|
||||
detect_source_type,
|
||||
# Type display helpers
|
||||
get_artifact_type_display,
|
||||
get_source_type_display,
|
||||
detect_source_type,
|
||||
ARTIFACT_TYPE_DISPLAY,
|
||||
ARTIFACT_TYPE_MAP,
|
||||
# Output helpers
|
||||
json_output_response,
|
||||
json_error_response,
|
||||
# Context helpers
|
||||
get_current_notebook,
|
||||
set_current_notebook,
|
||||
clear_context,
|
||||
get_current_conversation,
|
||||
set_current_conversation,
|
||||
require_notebook,
|
||||
# Error handling
|
||||
handle_error,
|
||||
handle_auth_error,
|
||||
get_auth_tokens,
|
||||
# Auth helpers
|
||||
get_client,
|
||||
get_auth_tokens,
|
||||
get_current_conversation,
|
||||
# Context helpers
|
||||
get_current_notebook,
|
||||
get_source_type_display,
|
||||
handle_auth_error,
|
||||
# Error handling
|
||||
handle_error,
|
||||
json_error_response,
|
||||
# Output helpers
|
||||
json_output_response,
|
||||
require_notebook,
|
||||
run_async,
|
||||
set_current_conversation,
|
||||
set_current_notebook,
|
||||
# Decorator
|
||||
with_client,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ARTIFACT TYPE DISPLAY TESTS
|
||||
# =============================================================================
|
||||
|
|
@ -92,15 +90,27 @@ class TestGetArtifactTypeDisplay:
|
|||
|
||||
class TestDetectSourceType:
|
||||
def test_youtube_url(self):
|
||||
src = ["id", "Video Title", [None, None, None, None, None, None, None, ["https://youtube.com/watch?v=abc"]]]
|
||||
src = [
|
||||
"id",
|
||||
"Video Title",
|
||||
[None, None, None, None, None, None, None, ["https://youtube.com/watch?v=abc"]],
|
||||
]
|
||||
assert detect_source_type(src) == "🎥 YouTube"
|
||||
|
||||
def test_youtu_be_url(self):
|
||||
src = ["id", "Video Title", [None, None, None, None, None, None, None, ["https://youtu.be/abc"]]]
|
||||
src = [
|
||||
"id",
|
||||
"Video Title",
|
||||
[None, None, None, None, None, None, None, ["https://youtu.be/abc"]],
|
||||
]
|
||||
assert detect_source_type(src) == "🎥 YouTube"
|
||||
|
||||
def test_web_url(self):
|
||||
src = ["id", "Web Page", [None, None, None, None, None, None, None, ["https://example.com/article"]]]
|
||||
src = [
|
||||
"id",
|
||||
"Web Page",
|
||||
[None, None, None, None, None, None, None, ["https://example.com/article"]],
|
||||
]
|
||||
assert detect_source_type(src) == "🔗 Web URL"
|
||||
|
||||
def test_pdf_file(self):
|
||||
|
|
@ -256,7 +266,9 @@ class TestJsonErrorResponse:
|
|||
|
||||
class TestContextManagement:
|
||||
def test_get_current_notebook_no_file(self, tmp_path):
|
||||
with patch("notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"):
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"
|
||||
):
|
||||
result = get_current_notebook()
|
||||
assert result is None
|
||||
|
||||
|
|
@ -271,10 +283,7 @@ class TestContextManagement:
|
|||
context_file = tmp_path / "context.json"
|
||||
with patch("notebooklm.cli.helpers.get_context_path", return_value=context_file):
|
||||
set_current_notebook(
|
||||
"nb_test123",
|
||||
title="Test Notebook",
|
||||
is_owner=True,
|
||||
created_at="2024-01-01T00:00:00"
|
||||
"nb_test123", title="Test Notebook", is_owner=True, created_at="2024-01-01T00:00:00"
|
||||
)
|
||||
data = json.loads(context_file.read_text())
|
||||
assert data["notebook_id"] == "nb_test123"
|
||||
|
|
@ -296,7 +305,9 @@ class TestContextManagement:
|
|||
clear_context() # Should not raise
|
||||
|
||||
def test_get_current_conversation_no_file(self, tmp_path):
|
||||
with patch("notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"):
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"
|
||||
):
|
||||
result = get_current_conversation()
|
||||
assert result is None
|
||||
|
||||
|
|
@ -327,7 +338,9 @@ class TestContextManagement:
|
|||
|
||||
class TestRequireNotebook:
|
||||
def test_returns_provided_notebook_id(self, tmp_path):
|
||||
with patch("notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "context.json"):
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "context.json"
|
||||
):
|
||||
result = require_notebook("nb_provided")
|
||||
assert result == "nb_provided"
|
||||
|
||||
|
|
@ -339,7 +352,9 @@ class TestRequireNotebook:
|
|||
assert result == "nb_context"
|
||||
|
||||
def test_raises_system_exit_when_no_notebook(self, tmp_path):
|
||||
with patch("notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"):
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.get_context_path", return_value=tmp_path / "nonexistent.json"
|
||||
):
|
||||
with patch("notebooklm.cli.helpers.console") as mock_console:
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
require_notebook(None)
|
||||
|
|
@ -395,14 +410,15 @@ class TestHandleAuthError:
|
|||
class TestWithClientDecorator:
|
||||
def test_decorator_passes_auth_to_function(self):
|
||||
"""Test that @with_client properly injects client_auth"""
|
||||
from click.testing import CliRunner
|
||||
import click
|
||||
from click.testing import CliRunner
|
||||
|
||||
@click.command()
|
||||
@with_client
|
||||
def test_cmd(ctx, client_auth):
|
||||
async def _run():
|
||||
click.echo(f"Got auth: {client_auth is not None}")
|
||||
|
||||
return _run()
|
||||
|
||||
runner = CliRunner()
|
||||
|
|
@ -417,14 +433,15 @@ class TestWithClientDecorator:
|
|||
|
||||
def test_decorator_handles_no_auth(self):
|
||||
"""Test that @with_client handles missing auth gracefully"""
|
||||
from click.testing import CliRunner
|
||||
import click
|
||||
from click.testing import CliRunner
|
||||
|
||||
@click.command()
|
||||
@with_client
|
||||
def test_cmd(ctx, client_auth):
|
||||
async def _run():
|
||||
pass
|
||||
|
||||
return _run()
|
||||
|
||||
runner = CliRunner()
|
||||
|
|
@ -437,14 +454,15 @@ class TestWithClientDecorator:
|
|||
|
||||
def test_decorator_handles_exception_non_json(self):
|
||||
"""Test error handling in non-JSON mode"""
|
||||
from click.testing import CliRunner
|
||||
import click
|
||||
from click.testing import CliRunner
|
||||
|
||||
@click.command()
|
||||
@with_client
|
||||
def test_cmd(ctx, client_auth):
|
||||
async def _run():
|
||||
raise ValueError("Test error")
|
||||
|
||||
return _run()
|
||||
|
||||
runner = CliRunner()
|
||||
|
|
@ -459,8 +477,8 @@ class TestWithClientDecorator:
|
|||
|
||||
def test_decorator_handles_exception_json_mode(self):
|
||||
"""Test error handling in JSON mode"""
|
||||
from click.testing import CliRunner
|
||||
import click
|
||||
from click.testing import CliRunner
|
||||
|
||||
@click.command()
|
||||
@click.option("--json", "json_output", is_flag=True)
|
||||
|
|
@ -468,6 +486,7 @@ class TestWithClientDecorator:
|
|||
def test_cmd(ctx, json_output, client_auth):
|
||||
async def _run():
|
||||
raise ValueError("Test error")
|
||||
|
||||
return _run()
|
||||
|
||||
runner = CliRunner()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Tests for note CLI commands."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from notebooklm.notebooklm_cli import cli
|
||||
|
|
@ -267,9 +267,7 @@ class TestNoteDelete:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["note", "delete", "note_123", "-n", "nb_123", "-y"]
|
||||
)
|
||||
result = runner.invoke(cli, ["note", "delete", "note_123", "-n", "nb_123", "-y"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Deleted note" in result.output
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
"""Tests for notebook CLI commands (now top-level commands)."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from notebooklm.notebooklm_cli import cli
|
||||
from notebooklm.types import Notebook, NotebookDescription, SuggestedTopic, AskResult
|
||||
from notebooklm.types import AskResult, Notebook
|
||||
|
||||
from .conftest import create_mock_client, patch_main_cli_client, patch_client_for_module
|
||||
from .conftest import create_mock_client, patch_client_for_module, patch_main_cli_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -56,8 +56,18 @@ class TestNotebookList:
|
|||
mock_client = create_mock_client()
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_1", title="First Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(id="nb_2", title="Second Notebook", created_at=datetime(2024, 1, 2), is_owner=False),
|
||||
Notebook(
|
||||
id="nb_1",
|
||||
title="First Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
Notebook(
|
||||
id="nb_2",
|
||||
title="Second Notebook",
|
||||
created_at=datetime(2024, 1, 2),
|
||||
is_owner=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
@ -75,7 +85,12 @@ class TestNotebookList:
|
|||
mock_client = create_mock_client()
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_1", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_1",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
@ -101,7 +116,9 @@ class TestNotebookCreate:
|
|||
with patch_main_cli_client() as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.notebooks.create = AsyncMock(
|
||||
return_value=Notebook(id="new_nb_id", title="Test Notebook", created_at=datetime(2024, 1, 1))
|
||||
return_value=Notebook(
|
||||
id="new_nb_id", title="Test Notebook", created_at=datetime(2024, 1, 1)
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
|
|
@ -116,7 +133,9 @@ class TestNotebookCreate:
|
|||
with patch_main_cli_client() as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.notebooks.create = AsyncMock(
|
||||
return_value=Notebook(id="new_nb_id", title="Test Notebook", created_at=datetime(2024, 1, 1))
|
||||
return_value=Notebook(
|
||||
id="new_nb_id", title="Test Notebook", created_at=datetime(2024, 1, 1)
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
|
|
@ -141,7 +160,12 @@ class TestNotebookDelete:
|
|||
# Mock list for partial ID resolution (returns the notebook to be deleted)
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_to_delete", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_to_delete",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.notebooks.delete = AsyncMock(return_value=True)
|
||||
|
|
@ -164,16 +188,25 @@ class TestNotebookDelete:
|
|||
# Mock list for partial ID resolution
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_to_delete", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_to_delete",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.notebooks.delete = AsyncMock(return_value=True)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.get_context_path", return_value=context_file):
|
||||
with patch("notebooklm.cli.notebook.get_current_notebook", return_value="nb_to_delete"):
|
||||
with patch(
|
||||
"notebooklm.cli.notebook.get_current_notebook", return_value="nb_to_delete"
|
||||
):
|
||||
with patch("notebooklm.cli.notebook.clear_context") as mock_clear:
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["delete", "-n", "nb_to_delete", "-y"])
|
||||
|
||||
|
|
@ -186,7 +219,12 @@ class TestNotebookDelete:
|
|||
# Mock list for partial ID resolution
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_123",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.notebooks.delete = AsyncMock(return_value=False)
|
||||
|
|
@ -212,7 +250,12 @@ class TestNotebookRename:
|
|||
# Mock list for partial ID resolution
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_123",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.notebooks.rename = AsyncMock(return_value=None)
|
||||
|
|
@ -239,7 +282,12 @@ class TestNotebookShare:
|
|||
# Mock list for partial ID resolution
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_123",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.notebooks.share = AsyncMock(
|
||||
|
|
@ -266,7 +314,12 @@ class TestNotebookShare:
|
|||
# Mock list for partial ID resolution
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_123",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.notebooks.share = AsyncMock(
|
||||
|
|
@ -299,7 +352,12 @@ class TestNotebookSummary:
|
|||
# Mock list for partial ID resolution
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_123",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_desc = MagicMock()
|
||||
|
|
@ -322,7 +380,12 @@ class TestNotebookSummary:
|
|||
# Mock list for partial ID resolution
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_123",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_desc = MagicMock()
|
||||
|
|
@ -347,7 +410,12 @@ class TestNotebookSummary:
|
|||
# Mock list for partial ID resolution
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(id="nb_123", title="Test Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(
|
||||
id="nb_123",
|
||||
title="Test Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.notebooks.get_description = AsyncMock(return_value=None)
|
||||
|
|
@ -370,9 +438,7 @@ class TestNotebookHistory:
|
|||
def test_notebook_history(self, runner, mock_auth):
|
||||
with patch_main_cli_client() as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.chat.get_history = AsyncMock(
|
||||
return_value=[[["conv_1"], ["conv_2"]]]
|
||||
)
|
||||
mock_client.chat.get_history = AsyncMock(return_value=[[["conv_1"], ["conv_2"]]])
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
|
|
@ -429,8 +495,13 @@ class TestNotebookAsk:
|
|||
mock_client.chat.get_history = AsyncMock(return_value=None)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.get_context_path", return_value=Path("/nonexistent/context.json")):
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.get_context_path",
|
||||
return_value=Path("/nonexistent/context.json"),
|
||||
):
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["ask", "-n", "nb_123", "What is this?"])
|
||||
|
||||
|
|
@ -455,7 +526,9 @@ class TestNotebookAsk:
|
|||
result = runner.invoke(cli, ["ask", "-n", "nb_123", "--new", "Fresh question"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Starting new conversation" in result.output or "New conversation" in result.output
|
||||
assert (
|
||||
"Starting new conversation" in result.output or "New conversation" in result.output
|
||||
)
|
||||
|
||||
def test_notebook_ask_continue_conversation(self, runner, mock_auth):
|
||||
with patch_main_cli_client() as mock_client_cls:
|
||||
|
|
@ -492,7 +565,9 @@ class TestNotebookConfigure:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["configure", "-n", "nb_123", "--mode", "learning-guide"])
|
||||
result = runner.invoke(
|
||||
cli, ["configure", "-n", "nb_123", "--mode", "learning-guide"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Chat mode set to: learning-guide" in result.output
|
||||
|
|
@ -505,7 +580,9 @@ class TestNotebookConfigure:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["configure", "-n", "nb_123", "--persona", "Act as a tutor"])
|
||||
result = runner.invoke(
|
||||
cli, ["configure", "-n", "nb_123", "--persona", "Act as a tutor"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Chat configured" in result.output
|
||||
|
|
@ -519,7 +596,9 @@ class TestNotebookConfigure:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["configure", "-n", "nb_123", "--response-length", "longer"])
|
||||
result = runner.invoke(
|
||||
cli, ["configure", "-n", "nb_123", "--response-length", "longer"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "response length: longer" in result.output
|
||||
|
|
@ -542,7 +621,9 @@ class TestSourceAddResearch:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["source", "add-research", "AI research", "-n", "nb_123"])
|
||||
result = runner.invoke(
|
||||
cli, ["source", "add-research", "AI research", "-n", "nb_123"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Found 1 sources" in result.output
|
||||
|
|
@ -555,7 +636,9 @@ class TestSourceAddResearch:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["source", "add-research", "AI research", "-n", "nb_123"])
|
||||
result = runner.invoke(
|
||||
cli, ["source", "add-research", "AI research", "-n", "nb_123"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Research failed to start" in result.output
|
||||
|
|
@ -572,7 +655,9 @@ class TestSourceAddResearch:
|
|||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["source", "add-research", "AI research", "-n", "nb_123", "--import-all"])
|
||||
result = runner.invoke(
|
||||
cli, ["source", "add-research", "AI research", "-n", "nb_123", "--import-all"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Imported 1 sources" in result.output
|
||||
|
|
@ -619,5 +704,7 @@ class TestNotebookCommandsExist:
|
|||
assert "ask" in result.output
|
||||
# Verify there's no "notebook" command in the Commands section
|
||||
# (it should only appear as part of "NotebookLM" in the description)
|
||||
commands_section = result.output.split("Commands:")[1] if "Commands:" in result.output else ""
|
||||
commands_section = (
|
||||
result.output.split("Commands:")[1] if "Commands:" in result.output else ""
|
||||
)
|
||||
assert " notebook " not in commands_section.lower()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
"""Tests for resolve_notebook_id and resolve_source_id partial ID matching."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
|
||||
from notebooklm.cli.helpers import resolve_notebook_id, resolve_source_id
|
||||
from notebooklm.types import Notebook, Source
|
||||
|
|
@ -22,9 +22,24 @@ def mock_client():
|
|||
def sample_notebooks():
|
||||
"""Sample notebooks for testing."""
|
||||
return [
|
||||
Notebook(id="abc123def456ghi789", title="First Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
Notebook(id="xyz789uvw456rst123", title="Second Notebook", created_at=datetime(2024, 1, 2), is_owner=False),
|
||||
Notebook(id="abc999zzz888yyy777", title="Third Notebook", created_at=datetime(2024, 1, 3), is_owner=True),
|
||||
Notebook(
|
||||
id="abc123def456ghi789",
|
||||
title="First Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
Notebook(
|
||||
id="xyz789uvw456rst123",
|
||||
title="Second Notebook",
|
||||
created_at=datetime(2024, 1, 2),
|
||||
is_owner=False,
|
||||
),
|
||||
Notebook(
|
||||
id="abc999zzz888yyy777",
|
||||
title="Third Notebook",
|
||||
created_at=datetime(2024, 1, 3),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -127,9 +142,16 @@ class TestResolveNotebookId:
|
|||
mock_client.notebooks.list = AsyncMock(return_value=sample_notebooks)
|
||||
|
||||
# Create a notebook with a short ID that we'll match exactly
|
||||
mock_client.notebooks.list = AsyncMock(return_value=[
|
||||
Notebook(id="shortid", title="Short ID Notebook", created_at=datetime(2024, 1, 1), is_owner=True),
|
||||
])
|
||||
mock_client.notebooks.list = AsyncMock(
|
||||
return_value=[
|
||||
Notebook(
|
||||
id="shortid",
|
||||
title="Short ID Notebook",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
is_owner=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch("notebooklm.cli.helpers.console") as mock_console:
|
||||
result = await resolve_notebook_id(mock_client, "shortid")
|
||||
|
|
@ -146,7 +168,12 @@ class TestResolveNotebookIdAmbiguityDisplay:
|
|||
async def test_shows_up_to_five_matches(self, mock_client):
|
||||
"""Ambiguous error shows up to 5 matching notebooks."""
|
||||
notebooks = [
|
||||
Notebook(id=f"abc{i}00000000000000", title=f"Notebook {i}", created_at=datetime(2024, 1, i + 1), is_owner=True)
|
||||
Notebook(
|
||||
id=f"abc{i}00000000000000",
|
||||
title=f"Notebook {i}",
|
||||
created_at=datetime(2024, 1, i + 1),
|
||||
is_owner=True,
|
||||
)
|
||||
for i in range(7)
|
||||
]
|
||||
mock_client.notebooks.list = AsyncMock(return_value=notebooks)
|
||||
|
|
@ -219,7 +246,9 @@ class TestResolveSourceId:
|
|||
mock_console.print.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambiguous_prefix_raises_exception(self, mock_client_with_sources, sample_sources):
|
||||
async def test_ambiguous_prefix_raises_exception(
|
||||
self, mock_client_with_sources, sample_sources
|
||||
):
|
||||
"""Ambiguous prefix (matches multiple) raises ClickException."""
|
||||
mock_client_with_sources.sources.list = AsyncMock(return_value=sample_sources)
|
||||
|
||||
|
|
@ -318,7 +347,9 @@ class TestResolveSourceIdAmbiguityDisplay:
|
|||
assert "... and 2 more" in error_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shows_source_titles_in_ambiguous_error(self, mock_client_with_sources, sample_sources):
|
||||
async def test_shows_source_titles_in_ambiguous_error(
|
||||
self, mock_client_with_sources, sample_sources
|
||||
):
|
||||
"""Ambiguous error includes source titles."""
|
||||
mock_client_with_sources.sources.list = AsyncMock(return_value=sample_sources)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
"""Tests for session CLI commands (login, use, status, clear)."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from notebooklm.notebooklm_cli import cli
|
||||
|
|
@ -101,9 +101,7 @@ class TestUseCommand:
|
|||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
|
||||
# Patch in session module where it's imported
|
||||
|
|
@ -131,9 +129,7 @@ class TestUseCommand:
|
|||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
|
||||
# Patch in session module where it's imported
|
||||
|
|
@ -174,9 +170,7 @@ class TestUseCommand:
|
|||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
|
||||
# Patch in session module where it's imported
|
||||
|
|
@ -319,14 +313,10 @@ class TestSessionEdgeCases:
|
|||
"""Test 'use' command handles API errors gracefully."""
|
||||
with patch_main_cli_client() as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.notebooks.get = AsyncMock(
|
||||
side_effect=Exception("API Error: Rate limited")
|
||||
)
|
||||
mock_client.notebooks.get = AsyncMock(side_effect=Exception("API Error: Rate limited"))
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock
|
||||
) as mock_fetch:
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
|
||||
# Patch in session module where it's imported
|
||||
|
|
|
|||
|
|
@ -1,11 +1,17 @@
|
|||
"""Tests for skill CLI commands."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from notebooklm.notebooklm_cli import cli
|
||||
|
||||
from .conftest import get_cli_module
|
||||
|
||||
# Get the actual skill module (not the click group that shadows it)
|
||||
skill_module = get_cli_module("skill")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
|
|
@ -20,10 +26,13 @@ class TestSkillInstall:
|
|||
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
|
||||
mock_source_content = "---\nname: notebooklm\n---\n# Test"
|
||||
|
||||
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest), \
|
||||
patch("notebooklm.cli.skill.SKILL_DEST_DIR", skill_dest.parent), \
|
||||
patch("notebooklm.cli.skill.get_skill_source_content", return_value=mock_source_content):
|
||||
|
||||
with (
|
||||
patch.object(skill_module, "SKILL_DEST", skill_dest),
|
||||
patch.object(skill_module, "SKILL_DEST_DIR", skill_dest.parent),
|
||||
patch.object(
|
||||
skill_module, "get_skill_source_content", return_value=mock_source_content
|
||||
),
|
||||
):
|
||||
result = runner.invoke(cli, ["skill", "install"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -34,9 +43,11 @@ class TestSkillInstall:
|
|||
"""Test error when source file doesn't exist."""
|
||||
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
|
||||
|
||||
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest), \
|
||||
patch("notebooklm.cli.skill.SKILL_DEST_DIR", skill_dest.parent), \
|
||||
patch("notebooklm.cli.skill.get_skill_source_content", return_value=None):
|
||||
with (
|
||||
patch.object(skill_module, "SKILL_DEST", skill_dest),
|
||||
patch.object(skill_module, "SKILL_DEST_DIR", skill_dest.parent),
|
||||
patch.object(skill_module, "get_skill_source_content", return_value=None),
|
||||
):
|
||||
result = runner.invoke(cli, ["skill", "install"])
|
||||
|
||||
assert result.exit_code == 1
|
||||
|
|
@ -50,7 +61,7 @@ class TestSkillStatus:
|
|||
"""Test status when skill is not installed."""
|
||||
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
|
||||
|
||||
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
|
||||
with patch.object(skill_module, "SKILL_DEST", skill_dest):
|
||||
result = runner.invoke(cli, ["skill", "status"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -62,7 +73,7 @@ class TestSkillStatus:
|
|||
skill_dest.parent.mkdir(parents=True)
|
||||
skill_dest.write_text("<!-- notebooklm-py v0.1.0 -->\n# Test")
|
||||
|
||||
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
|
||||
with patch.object(skill_module, "SKILL_DEST", skill_dest):
|
||||
result = runner.invoke(cli, ["skill", "status"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -78,8 +89,10 @@ class TestSkillUninstall:
|
|||
skill_dest.parent.mkdir(parents=True)
|
||||
skill_dest.write_text("# Test")
|
||||
|
||||
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest), \
|
||||
patch("notebooklm.cli.skill.SKILL_DEST_DIR", skill_dest.parent):
|
||||
with (
|
||||
patch.object(skill_module, "SKILL_DEST", skill_dest),
|
||||
patch.object(skill_module, "SKILL_DEST_DIR", skill_dest.parent),
|
||||
):
|
||||
result = runner.invoke(cli, ["skill", "uninstall"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -89,7 +102,7 @@ class TestSkillUninstall:
|
|||
"""Test uninstall when skill doesn't exist."""
|
||||
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
|
||||
|
||||
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
|
||||
with patch.object(skill_module, "SKILL_DEST", skill_dest):
|
||||
result = runner.invoke(cli, ["skill", "uninstall"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -105,7 +118,7 @@ class TestSkillShow:
|
|||
skill_dest.parent.mkdir(parents=True)
|
||||
skill_dest.write_text("# NotebookLM Skill\nTest content")
|
||||
|
||||
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
|
||||
with patch.object(skill_module, "SKILL_DEST", skill_dest):
|
||||
result = runner.invoke(cli, ["skill", "show"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -115,7 +128,7 @@ class TestSkillShow:
|
|||
"""Test show when skill doesn't exist."""
|
||||
skill_dest = tmp_path / "skills" / "notebooklm" / "SKILL.md"
|
||||
|
||||
with patch("notebooklm.cli.skill.SKILL_DEST", skill_dest):
|
||||
with patch.object(skill_module, "SKILL_DEST", skill_dest):
|
||||
result = runner.invoke(cli, ["skill", "show"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
"""Tests for source CLI commands."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from notebooklm.notebooklm_cli import cli
|
||||
|
|
@ -60,12 +59,15 @@ class TestSourceList:
|
|||
mock_client = create_mock_client()
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_1", title="Test Source", source_type="url", url="https://example.com"),
|
||||
Source(
|
||||
id="src_1",
|
||||
title="Test Source",
|
||||
source_type="url",
|
||||
url="https://example.com",
|
||||
),
|
||||
]
|
||||
)
|
||||
mock_client.notebooks.get = AsyncMock(
|
||||
return_value=MagicMock(title="Test Notebook")
|
||||
)
|
||||
mock_client.notebooks.get = AsyncMock(return_value=MagicMock(title="Test Notebook"))
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
|
|
@ -89,7 +91,9 @@ class TestSourceAdd:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.add_url = AsyncMock(
|
||||
return_value=Source(id="src_new", title="Example", url="https://example.com", source_type="url")
|
||||
return_value=Source(
|
||||
id="src_new", title="Example", url="https://example.com", source_type="url"
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
|
|
@ -105,7 +109,12 @@ class TestSourceAdd:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.add_url = AsyncMock(
|
||||
return_value=Source(id="src_yt", title="YouTube Video", url="https://youtube.com/watch?v=abc", source_type="youtube")
|
||||
return_value=Source(
|
||||
id="src_yt",
|
||||
title="YouTube Video",
|
||||
url="https://youtube.com/watch?v=abc",
|
||||
source_type="youtube",
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
|
|
@ -147,7 +156,17 @@ class TestSourceAdd:
|
|||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
["source", "add", "My notes", "--type", "text", "--title", "Custom Title", "-n", "nb_123"],
|
||||
[
|
||||
"source",
|
||||
"add",
|
||||
"My notes",
|
||||
"--type",
|
||||
"text",
|
||||
"--title",
|
||||
"Custom Title",
|
||||
"-n",
|
||||
"nb_123",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -177,7 +196,9 @@ class TestSourceAdd:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.add_url = AsyncMock(
|
||||
return_value=Source(id="src_new", title="Example", url="https://example.com", source_type="url")
|
||||
return_value=Source(
|
||||
id="src_new", title="Example", url="https://example.com", source_type="url"
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
|
|
@ -203,9 +224,7 @@ class TestSourceGet:
|
|||
mock_client = create_mock_client()
|
||||
# Mock sources.list for resolve_source_id
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.get = AsyncMock(
|
||||
return_value=Source(
|
||||
|
|
@ -213,7 +232,7 @@ class TestSourceGet:
|
|||
title="Test Source",
|
||||
source_type="url",
|
||||
url="https://example.com",
|
||||
created_at=datetime(2024, 1, 1)
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
@ -254,18 +273,14 @@ class TestSourceDelete:
|
|||
mock_client = create_mock_client()
|
||||
# Mock sources.list for resolve_source_id
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.delete = AsyncMock(return_value=True)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["source", "delete", "src_123", "-n", "nb_123", "-y"]
|
||||
)
|
||||
result = runner.invoke(cli, ["source", "delete", "src_123", "-n", "nb_123", "-y"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Deleted source" in result.output
|
||||
|
|
@ -276,18 +291,14 @@ class TestSourceDelete:
|
|||
mock_client = create_mock_client()
|
||||
# Mock sources.list for resolve_source_id
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.delete = AsyncMock(return_value=False)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["source", "delete", "src_123", "-n", "nb_123", "-y"]
|
||||
)
|
||||
result = runner.invoke(cli, ["source", "delete", "src_123", "-n", "nb_123", "-y"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Delete may have failed" in result.output
|
||||
|
|
@ -304,9 +315,7 @@ class TestSourceRename:
|
|||
mock_client = create_mock_client()
|
||||
# Mock sources.list for resolve_source_id
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Old Title", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Old Title", source_type="url")]
|
||||
)
|
||||
mock_client.sources.rename = AsyncMock(
|
||||
return_value=Source(id="src_123", title="New Title", source_type="url")
|
||||
|
|
@ -335,9 +344,7 @@ class TestSourceRefresh:
|
|||
mock_client = create_mock_client()
|
||||
# Mock sources.list for resolve_source_id
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Original Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Original Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.refresh = AsyncMock(
|
||||
return_value=Source(id="src_123", title="Refreshed Source", source_type="url")
|
||||
|
|
@ -356,9 +363,7 @@ class TestSourceRefresh:
|
|||
mock_client = create_mock_client()
|
||||
# Mock sources.list for resolve_source_id
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Original Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Original Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.refresh = AsyncMock(return_value=None)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
@ -405,7 +410,17 @@ class TestSourceAddDrive:
|
|||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(
|
||||
cli, ["source", "add-drive", "file_id", "PDF Title", "--mime-type", "pdf", "-n", "nb_123"]
|
||||
cli,
|
||||
[
|
||||
"source",
|
||||
"add-drive",
|
||||
"file_id",
|
||||
"PDF Title",
|
||||
"--mime-type",
|
||||
"pdf",
|
||||
"-n",
|
||||
"nb_123",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
|
@ -426,14 +441,12 @@ class TestSourceGuide:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.get_guide = AsyncMock(
|
||||
return_value={
|
||||
"summary": "This is a **test** summary about AI.",
|
||||
"keywords": ["AI", "machine learning", "data science"]
|
||||
"keywords": ["AI", "machine learning", "data science"],
|
||||
}
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
@ -452,13 +465,9 @@ class TestSourceGuide:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
)
|
||||
mock_client.sources.get_guide = AsyncMock(
|
||||
return_value={"summary": "", "keywords": []}
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.get_guide = AsyncMock(return_value={"summary": "", "keywords": []})
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
|
|
@ -472,21 +481,18 @@ class TestSourceGuide:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.get_guide = AsyncMock(
|
||||
return_value={
|
||||
"summary": "Test summary",
|
||||
"keywords": ["keyword1", "keyword2"]
|
||||
}
|
||||
return_value={"summary": "Test summary", "keywords": ["keyword1", "keyword2"]}
|
||||
)
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = ("csrf", "session")
|
||||
result = runner.invoke(cli, ["source", "guide", "src_123", "-n", "nb_123", "--json"])
|
||||
result = runner.invoke(
|
||||
cli, ["source", "guide", "src_123", "-n", "nb_123", "--json"]
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
data = json.loads(result.output)
|
||||
|
|
@ -499,9 +505,7 @@ class TestSourceGuide:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.get_guide = AsyncMock(
|
||||
return_value={"summary": "Summary without keywords", "keywords": []}
|
||||
|
|
@ -522,9 +526,7 @@ class TestSourceGuide:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.get_guide = AsyncMock(
|
||||
return_value={"summary": "", "keywords": ["AI", "ML", "Data"]}
|
||||
|
|
@ -552,9 +554,7 @@ class TestSourceStale:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.check_freshness = AsyncMock(return_value=False) # Not fresh = stale
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
@ -572,9 +572,7 @@ class TestSourceStale:
|
|||
with patch_client_for_module("source") as mock_client_cls:
|
||||
mock_client = create_mock_client()
|
||||
mock_client.sources.list = AsyncMock(
|
||||
return_value=[
|
||||
Source(id="src_123", title="Test Source", source_type="url")
|
||||
]
|
||||
return_value=[Source(id="src_123", title="Test Source", source_type="url")]
|
||||
)
|
||||
mock_client.sources.check_freshness = AsyncMock(return_value=True) # Fresh
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
"""Unit tests for new API coverage features."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from notebooklm import NotebookLMClient
|
||||
from notebooklm.auth import AuthTokens
|
||||
from notebooklm.rpc.types import (
|
||||
RPCMethod,
|
||||
ChatGoal,
|
||||
ChatResponseLength,
|
||||
DriveMimeType,
|
||||
RPCMethod,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
"""Unit tests for artifact download methods."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import tempfile
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from notebooklm import NotebookLMClient
|
||||
from notebooklm._artifacts import ArtifactsAPI
|
||||
from notebooklm.auth import AuthTokens
|
||||
|
||||
|
|
@ -42,26 +42,34 @@ class TestDownloadAudio:
|
|||
"""Test successful audio download."""
|
||||
api, mock_core = mock_artifacts_api
|
||||
# Mock artifact list response - type 1 (audio), status 3 (completed)
|
||||
mock_core.rpc_call.return_value = [[
|
||||
mock_core.rpc_call.return_value = [
|
||||
[
|
||||
"audio_001", # id
|
||||
"Audio Title", # title
|
||||
1, # type (audio)
|
||||
None, # ?
|
||||
3, # status (completed)
|
||||
None, # ?
|
||||
[None, None, None, None, None, [ # metadata[6][5] = media list
|
||||
["https://example.com/audio.mp4", None, "audio/mp4"]
|
||||
]],
|
||||
[
|
||||
"audio_001", # id
|
||||
"Audio Title", # title
|
||||
1, # type (audio)
|
||||
None, # ?
|
||||
3, # status (completed)
|
||||
None, # ?
|
||||
[
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
[ # metadata[6][5] = media list
|
||||
["https://example.com/audio.mp4", None, "audio/mp4"]
|
||||
],
|
||||
],
|
||||
]
|
||||
]
|
||||
]]
|
||||
]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
output_path = os.path.join(tmpdir, "audio.mp4")
|
||||
|
||||
with patch.object(
|
||||
api, '_download_url',
|
||||
new_callable=AsyncMock, return_value=output_path
|
||||
api, "_download_url", new_callable=AsyncMock, return_value=output_path
|
||||
):
|
||||
result = await api.download_audio("nb_123", output_path)
|
||||
|
||||
|
|
@ -80,22 +88,20 @@ class TestDownloadAudio:
|
|||
async def test_download_audio_specific_id_not_found(self, mock_artifacts_api):
|
||||
"""Test error when specific audio ID not found."""
|
||||
api, mock_core = mock_artifacts_api
|
||||
mock_core.rpc_call.return_value = [[
|
||||
["other_id", "Audio", 1, None, 3, None, [None] * 6]
|
||||
]]
|
||||
mock_core.rpc_call.return_value = [[["other_id", "Audio", 1, None, 3, None, [None] * 6]]]
|
||||
|
||||
with pytest.raises(ValueError, match="Audio artifact audio_001 not found"):
|
||||
await api.download_audio(
|
||||
"nb_123", "/tmp/audio.mp4", artifact_id="audio_001"
|
||||
)
|
||||
await api.download_audio("nb_123", "/tmp/audio.mp4", artifact_id="audio_001")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_audio_invalid_metadata(self, mock_artifacts_api):
|
||||
"""Test error on invalid metadata structure."""
|
||||
api, mock_core = mock_artifacts_api
|
||||
mock_core.rpc_call.return_value = [[
|
||||
["audio_001", "Audio", 1, None, 3, None, "not_a_list"] # metadata should be list
|
||||
]]
|
||||
mock_core.rpc_call.return_value = [
|
||||
[
|
||||
["audio_001", "Audio", 1, None, 3, None, "not_a_list"] # metadata should be list
|
||||
]
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid audio metadata|Failed to parse"):
|
||||
await api.download_audio("nb_123", "/tmp/audio.mp4")
|
||||
|
|
@ -113,16 +119,24 @@ class TestDownloadVideo:
|
|||
output_path = os.path.join(tmpdir, "video.mp4")
|
||||
|
||||
# Patch _list_raw to return video artifact data
|
||||
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
|
||||
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
|
||||
# Type 3 (video), status 3 (completed), metadata at index 8
|
||||
mock_list.return_value = [[
|
||||
"video_001", "Video Title", 3, None, 3, None, None, None,
|
||||
[[["https://example.com/video.mp4", 4, "video/mp4"]]]
|
||||
]]
|
||||
mock_list.return_value = [
|
||||
[
|
||||
"video_001",
|
||||
"Video Title",
|
||||
3,
|
||||
None,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
[[["https://example.com/video.mp4", 4, "video/mp4"]]],
|
||||
]
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
api, '_download_url',
|
||||
new_callable=AsyncMock, return_value=output_path
|
||||
api, "_download_url", new_callable=AsyncMock, return_value=output_path
|
||||
):
|
||||
result = await api.download_video("nb_123", output_path)
|
||||
|
||||
|
|
@ -133,7 +147,7 @@ class TestDownloadVideo:
|
|||
"""Test error when no video artifact exists."""
|
||||
api, mock_core = mock_artifacts_api
|
||||
|
||||
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
|
||||
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
|
||||
mock_list.return_value = []
|
||||
|
||||
with pytest.raises(ValueError, match="No completed video"):
|
||||
|
|
@ -144,15 +158,11 @@ class TestDownloadVideo:
|
|||
"""Test error when specific video ID not found."""
|
||||
api, mock_core = mock_artifacts_api
|
||||
|
||||
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
|
||||
mock_list.return_value = [
|
||||
["other_id", "Video", 3, None, 3, None, None, None, []]
|
||||
]
|
||||
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
|
||||
mock_list.return_value = [["other_id", "Video", 3, None, 3, None, None, None, []]]
|
||||
|
||||
with pytest.raises(ValueError, match="Video artifact video_001 not found"):
|
||||
await api.download_video(
|
||||
"nb_123", "/tmp/video.mp4", artifact_id="video_001"
|
||||
)
|
||||
await api.download_video("nb_123", "/tmp/video.mp4", artifact_id="video_001")
|
||||
|
||||
|
||||
class TestDownloadInfographic:
|
||||
|
|
@ -167,17 +177,25 @@ class TestDownloadInfographic:
|
|||
output_path = os.path.join(tmpdir, "infographic.png")
|
||||
|
||||
# Patch _list_raw to return infographic data
|
||||
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
|
||||
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
|
||||
# Type 7 (infographic), status 3, metadata with nested URL structure
|
||||
mock_list.return_value = [[
|
||||
"infographic_001", "Infographic Title", 7, None, 3,
|
||||
None, None, None, None,
|
||||
[[], [], [[None, ["https://example.com/infographic.png"]]]]
|
||||
]]
|
||||
mock_list.return_value = [
|
||||
[
|
||||
"infographic_001",
|
||||
"Infographic Title",
|
||||
7,
|
||||
None,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
[[], [], [[None, ["https://example.com/infographic.png"]]]],
|
||||
]
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
api, '_download_url',
|
||||
new_callable=AsyncMock, return_value=output_path
|
||||
api, "_download_url", new_callable=AsyncMock, return_value=output_path
|
||||
):
|
||||
result = await api.download_infographic("nb_123", output_path)
|
||||
|
||||
|
|
@ -188,7 +206,7 @@ class TestDownloadInfographic:
|
|||
"""Test error when no infographic artifact exists."""
|
||||
api, mock_core = mock_artifacts_api
|
||||
|
||||
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
|
||||
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
|
||||
mock_list.return_value = []
|
||||
|
||||
with pytest.raises(ValueError, match="No completed infographic"):
|
||||
|
|
@ -208,23 +226,24 @@ class TestDownloadSlideDeck:
|
|||
|
||||
# Patch _list_raw to return slide deck artifact data
|
||||
# Structure: artifact[16] = [config, title, slides_list, pdf_url]
|
||||
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
|
||||
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
|
||||
# Create artifact with 17+ elements, type 8 (slide deck), status 3
|
||||
artifact = ["slide_001", "Slide Deck Title", 8, None, 3]
|
||||
# Pad to index 16
|
||||
artifact.extend([None] * 11)
|
||||
# Index 16: metadata with PDF URL at position 3
|
||||
artifact.append([
|
||||
["config"],
|
||||
"Slide Deck Title",
|
||||
[["slide1"], ["slide2"]], # slides_list
|
||||
"https://contribution.usercontent.google.com/download?filename=test.pdf"
|
||||
])
|
||||
artifact.append(
|
||||
[
|
||||
["config"],
|
||||
"Slide Deck Title",
|
||||
[["slide1"], ["slide2"]], # slides_list
|
||||
"https://contribution.usercontent.google.com/download?filename=test.pdf",
|
||||
]
|
||||
)
|
||||
mock_list.return_value = [artifact]
|
||||
|
||||
with patch.object(
|
||||
api, '_download_url',
|
||||
new_callable=AsyncMock, return_value=output_path
|
||||
api, "_download_url", new_callable=AsyncMock, return_value=output_path
|
||||
):
|
||||
result = await api.download_slide_deck("nb_123", output_path)
|
||||
|
||||
|
|
@ -235,7 +254,7 @@ class TestDownloadSlideDeck:
|
|||
"""Test error when no slide deck artifact exists."""
|
||||
api, mock_core = mock_artifacts_api
|
||||
|
||||
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
|
||||
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
|
||||
mock_list.return_value = []
|
||||
|
||||
with pytest.raises(ValueError, match="No completed slide"):
|
||||
|
|
@ -246,7 +265,7 @@ class TestDownloadSlideDeck:
|
|||
"""Test error when specific slide deck ID not found."""
|
||||
api, mock_core = mock_artifacts_api
|
||||
|
||||
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
|
||||
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
|
||||
# Need at least 17 elements for valid structure
|
||||
artifact = ["other_id", "Slides", 8, None, 3]
|
||||
artifact.extend([None] * 11)
|
||||
|
|
@ -261,7 +280,7 @@ class TestDownloadSlideDeck:
|
|||
"""Test error on invalid metadata structure."""
|
||||
api, mock_core = mock_artifacts_api
|
||||
|
||||
with patch.object(api, '_list_raw', new_callable=AsyncMock) as mock_list:
|
||||
with patch.object(api, "_list_raw", new_callable=AsyncMock) as mock_list:
|
||||
# Create artifact with invalid metadata (less than 4 elements)
|
||||
artifact = ["slide_001", "Slides", 8, None, 3]
|
||||
artifact.extend([None] * 11)
|
||||
|
|
@ -284,11 +303,13 @@ class TestMindMapGeneration:
|
|||
# _get_source_ids response
|
||||
[None, None, None, None, None, [[["src_001"]]]],
|
||||
# generate_mind_map response
|
||||
[[
|
||||
'{"nodes": [{"id": "1", "text": "Root"}]}', # JSON string
|
||||
None,
|
||||
["note_123"], # note info (not used anymore, note is created explicitly)
|
||||
]],
|
||||
[
|
||||
[
|
||||
'{"nodes": [{"id": "1", "text": "Root"}]}', # JSON string
|
||||
None,
|
||||
["note_123"], # note info (not used anymore, note is created explicitly)
|
||||
]
|
||||
],
|
||||
]
|
||||
|
||||
result = await api.generate_mind_map("nb_123")
|
||||
|
|
@ -304,11 +325,13 @@ class TestMindMapGeneration:
|
|||
api, mock_core = mock_artifacts_api
|
||||
mock_core.rpc_call.side_effect = [
|
||||
[None, None, None, None, None, [[["src_001"]]]],
|
||||
[[
|
||||
{"nodes": [{"id": "1"}]}, # Already a dict
|
||||
None,
|
||||
["note_456"], # note info (not used anymore)
|
||||
]],
|
||||
[
|
||||
[
|
||||
{"nodes": [{"id": "1"}]}, # Already a dict
|
||||
None,
|
||||
["note_456"], # note info (not used anymore)
|
||||
]
|
||||
],
|
||||
]
|
||||
|
||||
result = await api.generate_mind_map("nb_123")
|
||||
|
|
@ -359,12 +382,10 @@ class TestDownloadUrl:
|
|||
|
||||
# Mock load_httpx_cookies to avoid requiring real auth files
|
||||
mock_cookies = MagicMock()
|
||||
with patch.object(real_httpx, 'AsyncClient', return_value=mock_client), \
|
||||
patch('notebooklm._artifacts.load_httpx_cookies', return_value=mock_cookies):
|
||||
result = await api._download_url(
|
||||
"https://other.example.com/file.mp4", output_path
|
||||
)
|
||||
with (
|
||||
patch.object(real_httpx, "AsyncClient", return_value=mock_client),
|
||||
patch("notebooklm._artifacts.load_httpx_cookies", return_value=mock_cookies),
|
||||
):
|
||||
result = await api._download_url("https://other.example.com/file.mp4", output_path)
|
||||
|
||||
assert result == output_path
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""Tests for authentication module."""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pytest_httpx import HTTPXMock
|
||||
|
||||
from notebooklm.auth import (
|
||||
|
|
@ -10,9 +11,9 @@ from notebooklm.auth import (
|
|||
extract_cookies_from_storage,
|
||||
extract_csrf_from_html,
|
||||
extract_session_id_from_html,
|
||||
fetch_tokens,
|
||||
load_auth_from_storage,
|
||||
load_httpx_cookies,
|
||||
fetch_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -254,7 +255,9 @@ class TestLoadAuthFromEnvVar:
|
|||
|
||||
# Set NOTEBOOKLM_HOME to tmp_path and create a file there
|
||||
monkeypatch.setenv("NOTEBOOKLM_HOME", str(tmp_path))
|
||||
file_storage = {"cookies": [{"name": "SID", "value": "from_home_file", "domain": ".google.com"}]}
|
||||
file_storage = {
|
||||
"cookies": [{"name": "SID", "value": "from_home_file", "domain": ".google.com"}]
|
||||
}
|
||||
storage_file = tmp_path / "storage_state.json"
|
||||
storage_file.write_text(json.dumps(file_storage))
|
||||
|
||||
|
|
@ -266,28 +269,36 @@ class TestLoadAuthFromEnvVar:
|
|||
"""Test that empty string NOTEBOOKLM_AUTH_JSON raises ValueError."""
|
||||
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", "")
|
||||
|
||||
with pytest.raises(ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"):
|
||||
with pytest.raises(
|
||||
ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"
|
||||
):
|
||||
load_auth_from_storage()
|
||||
|
||||
def test_env_var_whitespace_only_raises_value_error(self, monkeypatch):
|
||||
"""Test that whitespace-only NOTEBOOKLM_AUTH_JSON raises ValueError."""
|
||||
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", " \n\t ")
|
||||
|
||||
with pytest.raises(ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"):
|
||||
with pytest.raises(
|
||||
ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"
|
||||
):
|
||||
load_auth_from_storage()
|
||||
|
||||
def test_env_var_missing_cookies_key_raises_value_error(self, monkeypatch):
|
||||
"""Test that NOTEBOOKLM_AUTH_JSON without 'cookies' key raises ValueError."""
|
||||
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", '{"origins": []}')
|
||||
|
||||
with pytest.raises(ValueError, match="must contain valid Playwright storage state with a 'cookies' key"):
|
||||
with pytest.raises(
|
||||
ValueError, match="must contain valid Playwright storage state with a 'cookies' key"
|
||||
):
|
||||
load_auth_from_storage()
|
||||
|
||||
def test_env_var_non_dict_raises_value_error(self, monkeypatch):
|
||||
"""Test that non-dict NOTEBOOKLM_AUTH_JSON raises ValueError."""
|
||||
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", '["not", "a", "dict"]')
|
||||
|
||||
with pytest.raises(ValueError, match="must contain valid Playwright storage state with a 'cookies' key"):
|
||||
with pytest.raises(
|
||||
ValueError, match="must contain valid Playwright storage state with a 'cookies' key"
|
||||
):
|
||||
load_auth_from_storage()
|
||||
|
||||
|
||||
|
|
@ -327,7 +338,9 @@ class TestLoadHttpxCookiesWithEnvVar:
|
|||
"""Test that empty string NOTEBOOKLM_AUTH_JSON raises ValueError."""
|
||||
monkeypatch.setenv("NOTEBOOKLM_AUTH_JSON", "")
|
||||
|
||||
with pytest.raises(ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"):
|
||||
with pytest.raises(
|
||||
ValueError, match="NOTEBOOKLM_AUTH_JSON environment variable is set but empty"
|
||||
):
|
||||
load_httpx_cookies()
|
||||
|
||||
def test_env_var_missing_required_cookies_raises(self, monkeypatch):
|
||||
|
|
@ -640,7 +653,6 @@ class TestDefaultStoragePath:
|
|||
def test_default_storage_path_is_correct(self):
|
||||
"""Test DEFAULT_STORAGE_PATH constant is defined correctly."""
|
||||
from notebooklm.auth import DEFAULT_STORAGE_PATH
|
||||
from pathlib import Path
|
||||
|
||||
assert DEFAULT_STORAGE_PATH is not None
|
||||
assert isinstance(DEFAULT_STORAGE_PATH, Path)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
"""Tests for NotebookLMClient class."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_httpx import HTTPXMock
|
||||
|
||||
from notebooklm.client import NotebookLMClient
|
||||
from notebooklm.auth import AuthTokens
|
||||
from notebooklm.client import NotebookLMClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -172,7 +170,7 @@ class TestRefreshAuth:
|
|||
client = NotebookLMClient(mock_auth)
|
||||
|
||||
# Mock the homepage response with new tokens
|
||||
html = '''
|
||||
html = """
|
||||
<html>
|
||||
<script>
|
||||
window.WIZ_global_data = {
|
||||
|
|
@ -181,7 +179,7 @@ class TestRefreshAuth:
|
|||
};
|
||||
</script>
|
||||
</html>
|
||||
'''
|
||||
"""
|
||||
httpx_mock.add_response(
|
||||
url="https://notebooklm.google.com/",
|
||||
content=html.encode(),
|
||||
|
|
@ -208,7 +206,7 @@ class TestRefreshAuth:
|
|||
# The refresh_auth checks if "accounts.google.com" is in the final URL
|
||||
# We can't easily mock a real redirect with httpx, so we test the URL check
|
||||
# by providing a response that doesn't contain the expected tokens
|
||||
html = '<html><body>Please sign in</body></html>' # No tokens
|
||||
html = "<html><body>Please sign in</body></html>" # No tokens
|
||||
httpx_mock.add_response(
|
||||
url="https://notebooklm.google.com/",
|
||||
content=html.encode(),
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
"""Tests for conversation functionality."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import json
|
||||
|
||||
from notebooklm import NotebookLMClient, AskResult
|
||||
import pytest
|
||||
|
||||
from notebooklm import AskResult, NotebookLMClient
|
||||
from notebooklm.auth import AuthTokens
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
"""Unit tests for RPC response decoder."""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from notebooklm.rpc.decoder import (
|
||||
strip_anti_xssi,
|
||||
parse_chunked_response,
|
||||
extract_rpc_result,
|
||||
RPCError,
|
||||
collect_rpc_ids,
|
||||
decode_response,
|
||||
RPCError,
|
||||
extract_rpc_result,
|
||||
parse_chunked_response,
|
||||
strip_anti_xssi,
|
||||
)
|
||||
from notebooklm.rpc.types import RPCMethod
|
||||
|
||||
|
|
@ -192,9 +193,7 @@ class TestExtractRPCResult:
|
|||
def test_user_displayable_error_sets_code(self):
|
||||
"""Test UserDisplayableError sets code to USER_DISPLAYABLE_ERROR."""
|
||||
error_info = [8, None, [["UserDisplayableError", []]]]
|
||||
chunks = [
|
||||
["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, error_info]
|
||||
]
|
||||
chunks = [["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, error_info]]
|
||||
|
||||
with pytest.raises(RPCError) as exc_info:
|
||||
extract_rpc_result(chunks, RPCMethod.LIST_NOTEBOOKS.value)
|
||||
|
|
@ -203,18 +202,14 @@ class TestExtractRPCResult:
|
|||
|
||||
def test_null_result_without_error_info_returns_none(self):
|
||||
"""Test null result without UserDisplayableError returns None normally."""
|
||||
chunks = [
|
||||
["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, None]
|
||||
]
|
||||
chunks = [["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, None]]
|
||||
|
||||
result = extract_rpc_result(chunks, RPCMethod.LIST_NOTEBOOKS.value)
|
||||
assert result is None
|
||||
|
||||
def test_null_result_with_non_error_info_returns_none(self):
|
||||
"""Test null result with non-error data at index 5 returns None."""
|
||||
chunks = [
|
||||
["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, [1, 2, 3]]
|
||||
]
|
||||
chunks = [["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, [1, 2, 3]]]
|
||||
|
||||
result = extract_rpc_result(chunks, RPCMethod.LIST_NOTEBOOKS.value)
|
||||
assert result is None
|
||||
|
|
@ -229,9 +224,7 @@ class TestExtractRPCResult:
|
|||
"type": "type.googleapis.com/google.internal.labs.tailwind.orchestration.v1.UserDisplayableError",
|
||||
"details": {"code": 1},
|
||||
}
|
||||
chunks = [
|
||||
["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, error_info]
|
||||
]
|
||||
chunks = [["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, None, None, None, error_info]]
|
||||
|
||||
with pytest.raises(RPCError, match="rate limiting"):
|
||||
extract_rpc_result(chunks, RPCMethod.LIST_NOTEBOOKS.value)
|
||||
|
|
@ -267,9 +260,7 @@ class TestDecodeResponse:
|
|||
|
||||
def test_decode_complex_nested_data(self):
|
||||
"""Test decoding complex nested data structures."""
|
||||
data = {
|
||||
"notebooks": [{"id": "nb1", "title": "Test", "sources": [{"id": "s1"}]}]
|
||||
}
|
||||
data = {"notebooks": [{"id": "nb1", "title": "Test", "sources": [{"id": "s1"}]}]}
|
||||
inner = json.dumps(data)
|
||||
chunk = json.dumps(["wrb.fr", RPCMethod.LIST_NOTEBOOKS.value, inner, None, None])
|
||||
raw_response = f")]}}'\n{len(chunk)}\n{chunk}\n"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
"""Tests for download helper functions."""
|
||||
|
||||
import pytest
|
||||
from notebooklm.cli.download_helpers import select_artifact, artifact_title_to_filename
|
||||
|
||||
from notebooklm.cli.download_helpers import artifact_title_to_filename, select_artifact
|
||||
|
||||
|
||||
class TestSelectArtifact:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
"""Unit tests for RPC request encoder."""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from urllib.parse import unquote
|
||||
|
||||
from notebooklm.rpc.encoder import encode_rpc_request, build_request_body, build_url_params
|
||||
from notebooklm.rpc.encoder import build_request_body, build_url_params, encode_rpc_request
|
||||
from notebooklm.rpc.types import RPCMethod
|
||||
|
||||
|
||||
|
|
@ -134,28 +132,21 @@ class TestBuildUrlParams:
|
|||
|
||||
def test_with_source_path(self):
|
||||
"""Test URL params with custom source path."""
|
||||
result = build_url_params(
|
||||
RPCMethod.GET_NOTEBOOK,
|
||||
source_path="/notebook/abc123"
|
||||
)
|
||||
result = build_url_params(RPCMethod.GET_NOTEBOOK, source_path="/notebook/abc123")
|
||||
|
||||
assert result["rpcids"] == RPCMethod.GET_NOTEBOOK.value
|
||||
assert result["source-path"] == "/notebook/abc123"
|
||||
|
||||
def test_with_session_id(self):
|
||||
"""Test URL params with session ID."""
|
||||
result = build_url_params(
|
||||
RPCMethod.LIST_NOTEBOOKS,
|
||||
session_id="session_12345"
|
||||
)
|
||||
result = build_url_params(RPCMethod.LIST_NOTEBOOKS, session_id="session_12345")
|
||||
|
||||
assert result["f.sid"] == "session_12345"
|
||||
|
||||
def test_with_build_label(self):
|
||||
"""Test URL params with build label."""
|
||||
result = build_url_params(
|
||||
RPCMethod.LIST_NOTEBOOKS,
|
||||
bl="boq_labs-tailwind-frontend_20250101"
|
||||
RPCMethod.LIST_NOTEBOOKS, bl="boq_labs-tailwind-frontend_20250101"
|
||||
)
|
||||
|
||||
assert result["bl"] == "boq_labs-tailwind-frontend_20250101"
|
||||
|
|
@ -166,7 +157,7 @@ class TestBuildUrlParams:
|
|||
RPCMethod.CREATE_NOTEBOOK,
|
||||
source_path="/notebook/xyz789",
|
||||
session_id="sess_abc",
|
||||
bl="build_label_123"
|
||||
bl="build_label_123",
|
||||
)
|
||||
|
||||
assert result["rpcids"] == RPCMethod.CREATE_NOTEBOOK.value
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
"""Unit tests for NotesAPI private helpers and edge cases."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from notebooklm._notes import NotesAPI
|
||||
from notebooklm.types import Note
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -1,16 +1,15 @@
|
|||
"""Tests for path resolution module."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from notebooklm.paths import (
|
||||
get_home_dir,
|
||||
get_storage_path,
|
||||
get_context_path,
|
||||
get_browser_profile_dir,
|
||||
get_context_path,
|
||||
get_home_dir,
|
||||
get_path_info,
|
||||
get_storage_path,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
"""Tests for research functionality."""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from notebooklm import NotebookLMClient
|
||||
from notebooklm.auth import AuthTokens
|
||||
from notebooklm.rpc import RPCMethod
|
||||
|
|
@ -55,9 +57,7 @@ class TestResearch:
|
|||
]
|
||||
|
||||
response_json = json.dumps([[["task_123", task_info]]])
|
||||
chunk = json.dumps(
|
||||
["wrb.fr", RPCMethod.POLL_RESEARCH.value, response_json, None, None]
|
||||
)
|
||||
chunk = json.dumps(["wrb.fr", RPCMethod.POLL_RESEARCH.value, response_json, None, None])
|
||||
response_body = f")]}}'\n{len(chunk)}\n{chunk}\n"
|
||||
|
||||
httpx_mock.add_response(content=response_body.encode(), method="POST")
|
||||
|
|
@ -73,9 +73,7 @@ class TestResearch:
|
|||
@pytest.mark.asyncio
|
||||
async def test_import_research(self, auth_tokens, httpx_mock):
|
||||
response_json = json.dumps([[[["src_new"], "Imported Title"]]])
|
||||
chunk = json.dumps(
|
||||
["wrb.fr", RPCMethod.IMPORT_RESEARCH.value, response_json, None, None]
|
||||
)
|
||||
chunk = json.dumps(["wrb.fr", RPCMethod.IMPORT_RESEARCH.value, response_json, None, None])
|
||||
response_body = f")]}}'\n{len(chunk)}\n{chunk}\n"
|
||||
|
||||
httpx_mock.add_response(content=response_body.encode(), method="POST")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
"""Unit tests for RPC types and constants."""
|
||||
|
||||
import pytest
|
||||
from notebooklm.rpc.types import (
|
||||
RPCMethod,
|
||||
StudioContentType,
|
||||
BATCHEXECUTE_URL,
|
||||
QUERY_URL,
|
||||
RPCMethod,
|
||||
StudioContentType,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -13,8 +12,7 @@ class TestRPCConstants:
|
|||
def test_batchexecute_url(self):
|
||||
"""Test batchexecute URL is correct."""
|
||||
assert (
|
||||
BATCHEXECUTE_URL
|
||||
== "https://notebooklm.google.com/_/LabsTailwindUi/data/batchexecute"
|
||||
BATCHEXECUTE_URL == "https://notebooklm.google.com/_/LabsTailwindUi/data/batchexecute"
|
||||
)
|
||||
|
||||
def test_query_url(self):
|
||||
|
|
|
|||
|
|
@ -1,18 +1,19 @@
|
|||
"""Unit tests for source status and polling functionality."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from notebooklm._sources import SourcesAPI
|
||||
from notebooklm.types import (
|
||||
Source,
|
||||
SourceStatus,
|
||||
SourceError,
|
||||
SourceProcessingError,
|
||||
SourceTimeoutError,
|
||||
SourceNotFoundError,
|
||||
SourceProcessingError,
|
||||
SourceStatus,
|
||||
SourceTimeoutError,
|
||||
)
|
||||
from notebooklm._sources import SourcesAPI
|
||||
|
||||
|
||||
class TestSourceStatus:
|
||||
|
|
@ -27,7 +28,7 @@ class TestSourceStatus:
|
|||
def test_status_is_int_enum(self):
|
||||
"""Test that SourceStatus values can be compared with ints."""
|
||||
assert SourceStatus.READY == 2
|
||||
assert 2 == SourceStatus.READY
|
||||
assert SourceStatus.READY == 2
|
||||
|
||||
|
||||
class TestSourceStatusProperties:
|
||||
|
|
@ -235,9 +236,7 @@ class TestWaitForSources:
|
|||
raise SourceNotFoundError(source_id)
|
||||
|
||||
with patch.object(sources_api, "wait_until_ready", side_effect=mock_wait):
|
||||
results = await sources_api.wait_for_sources(
|
||||
"nb_1", ["src_1", "src_2"], timeout=10.0
|
||||
)
|
||||
results = await sources_api.wait_for_sources("nb_1", ["src_1", "src_2"], timeout=10.0)
|
||||
|
||||
assert len(results) == 2
|
||||
assert all(s.is_ready for s in results)
|
||||
|
|
@ -253,6 +252,4 @@ class TestWaitForSources:
|
|||
|
||||
with patch.object(sources_api, "wait_until_ready", side_effect=mock_wait):
|
||||
with pytest.raises(SourceProcessingError):
|
||||
await sources_api.wait_for_sources(
|
||||
"nb_1", ["src_1", "src_2"], timeout=10.0
|
||||
)
|
||||
await sources_api.wait_for_sources("nb_1", ["src_1", "src_2"], timeout=10.0)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Unit tests for SourcesAPI file upload pipeline and YouTube detection."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch, mock_open
|
||||
from pathlib import Path
|
||||
|
||||
from notebooklm._sources import SourcesAPI
|
||||
|
||||
|
|
@ -223,7 +223,9 @@ class TestStartResumableUpload:
|
|||
assert body["SOURCE_ID"] == "src_abc"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_resumable_upload_raises_on_missing_url_header(self, sources_api, mock_core):
|
||||
async def test_start_resumable_upload_raises_on_missing_url_header(
|
||||
self, sources_api, mock_core
|
||||
):
|
||||
"""Test that missing upload URL header raises ValueError."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.headers = {} # No x-goog-upload-url
|
||||
|
|
@ -286,7 +288,9 @@ class TestUploadFileStreaming:
|
|||
mock_client.post.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_streaming_includes_correct_headers(self, sources_api, mock_core, tmp_path):
|
||||
async def test_upload_file_streaming_includes_correct_headers(
|
||||
self, sources_api, mock_core, tmp_path
|
||||
):
|
||||
"""Test that streaming upload includes correct headers."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_bytes(b"content")
|
||||
|
|
@ -325,9 +329,7 @@ class TestUploadFileStreaming:
|
|||
mock_client.post.return_value = mock_response
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
await sources_api._upload_file_streaming(
|
||||
"https://upload.example.com", test_file
|
||||
)
|
||||
await sources_api._upload_file_streaming("https://upload.example.com", test_file)
|
||||
|
||||
call_kwargs = mock_client.post.call_args[1]
|
||||
# Content should be a generator, not bytes
|
||||
|
|
@ -337,7 +339,9 @@ class TestUploadFileStreaming:
|
|||
assert b"".join(chunks) == test_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_streaming_raises_on_http_error(self, sources_api, mock_core, tmp_path):
|
||||
async def test_upload_file_streaming_raises_on_http_error(
|
||||
self, sources_api, mock_core, tmp_path
|
||||
):
|
||||
"""Test that HTTP error raises exception."""
|
||||
import httpx
|
||||
|
||||
|
|
@ -354,9 +358,7 @@ class TestUploadFileStreaming:
|
|||
mock_client_cls.return_value = mock_client
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await sources_api._upload_file_streaming(
|
||||
"https://upload.example.com", test_file
|
||||
)
|
||||
await sources_api._upload_file_streaming("https://upload.example.com", test_file)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -1,20 +1,18 @@
|
|||
"""Unit tests for types module dataclasses and parsing."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from notebooklm.types import (
|
||||
Notebook,
|
||||
NotebookDescription,
|
||||
SuggestedTopic,
|
||||
Source,
|
||||
Artifact,
|
||||
GenerationStatus,
|
||||
ReportSuggestion,
|
||||
Note,
|
||||
ConversationTurn,
|
||||
AskResult,
|
||||
ChatMode,
|
||||
ConversationTurn,
|
||||
GenerationStatus,
|
||||
Note,
|
||||
Notebook,
|
||||
NotebookDescription,
|
||||
ReportSuggestion,
|
||||
Source,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -110,7 +108,13 @@ class TestSource:
|
|||
|
||||
def test_from_api_response_nested_format(self):
|
||||
"""Test parsing medium nested format."""
|
||||
data = [[["src_456"], "Nested Source", [None, None, None, None, None, None, None, ["https://example.com"]]]]
|
||||
data = [
|
||||
[
|
||||
["src_456"],
|
||||
"Nested Source",
|
||||
[None, None, None, None, None, None, None, ["https://example.com"]],
|
||||
]
|
||||
]
|
||||
source = Source.from_api_response(data)
|
||||
|
||||
assert source.id == "src_456"
|
||||
|
|
@ -119,7 +123,15 @@ class TestSource:
|
|||
|
||||
def test_from_api_response_deeply_nested(self):
|
||||
"""Test parsing deeply nested format."""
|
||||
data = [[[["src_789"], "Deep Source", [None, None, None, None, None, None, None, ["https://deep.example.com"]]]]]
|
||||
data = [
|
||||
[
|
||||
[
|
||||
["src_789"],
|
||||
"Deep Source",
|
||||
[None, None, None, None, None, None, None, ["https://deep.example.com"]],
|
||||
]
|
||||
]
|
||||
]
|
||||
source = Source.from_api_response(data)
|
||||
|
||||
assert source.id == "src_789"
|
||||
|
|
@ -128,14 +140,30 @@ class TestSource:
|
|||
|
||||
def test_from_api_response_youtube_url(self):
|
||||
"""Test that YouTube URLs are detected."""
|
||||
data = [[[["src_yt"], "YouTube Video", [None, None, None, None, None, None, None, ["https://youtube.com/watch?v=abc"]]]]]
|
||||
data = [
|
||||
[
|
||||
[
|
||||
["src_yt"],
|
||||
"YouTube Video",
|
||||
[None, None, None, None, None, None, None, ["https://youtube.com/watch?v=abc"]],
|
||||
]
|
||||
]
|
||||
]
|
||||
source = Source.from_api_response(data)
|
||||
|
||||
assert source.source_type == "youtube"
|
||||
|
||||
def test_from_api_response_youtu_be_short_url(self):
|
||||
"""Test that youtu.be short URLs are detected."""
|
||||
data = [[[["src_yt2"], "Short Video", [None, None, None, None, None, None, None, ["https://youtu.be/abc"]]]]]
|
||||
data = [
|
||||
[
|
||||
[
|
||||
["src_yt2"],
|
||||
"Short Video",
|
||||
[None, None, None, None, None, None, None, ["https://youtu.be/abc"]],
|
||||
]
|
||||
]
|
||||
]
|
||||
source = Source.from_api_response(data)
|
||||
|
||||
assert source.source_type == "youtube"
|
||||
|
|
@ -165,7 +193,24 @@ class TestArtifact:
|
|||
def test_from_api_response_with_timestamp(self):
|
||||
"""Test parsing artifact with timestamp."""
|
||||
ts = 1704067200
|
||||
data = ["art_123", "Audio", 1, None, 3, None, None, None, None, None, None, None, None, None, None, [ts]]
|
||||
data = [
|
||||
"art_123",
|
||||
"Audio",
|
||||
1,
|
||||
None,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
[ts],
|
||||
]
|
||||
artifact = Artifact.from_api_response(data)
|
||||
|
||||
assert artifact.created_at is not None
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""Unit tests for YouTube URL extraction."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from notebooklm import NotebookLMClient
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue