feat: access ActivityPub client through interfaces to facilitate mocking in unit tests ()

Was facing issues while writing unit tests for federation code. Mocks weren't catching all network calls, because was being out of scope of the mocking infra. Plus, I think we can have more granular tests.

This PR puts the client behind an interface, that can be retrieved from `ctx`. Context doesn't require initialization, as it defaults to the implementation available in-tree. It may be overridden when required (like testing).

## Mechanism

1. Get client factory from `ctx` (factory contains network and crypto parameters that are needed)
2. Initialize client with sender's keys and the receiver's public key
3. Use client as before.

Reviewed-on: https://codeberg.org/forgejo/forgejo/pulls/4853
Reviewed-by: Earl Warren <earl-warren@noreply.codeberg.org>
Co-authored-by: Aravinth Manivannan <realaravinth@batsense.net>
Co-committed-by: Aravinth Manivannan <realaravinth@batsense.net>
This commit is contained in:
Aravinth Manivannan 2024-08-07 05:45:24 +00:00 committed by Earl Warren
parent 1ddf44edd6
commit f9cbea3d6b
6 changed files with 140 additions and 25 deletions

View file

@ -93,6 +93,10 @@ code.gitea.io/gitea/models/user
GetUserEmailsByNames
GetUserNamesByIDs
code.gitea.io/gitea/modules/activitypub
NewContext
Context.APClientFactory
code.gitea.io/gitea/modules/assetfs
Bindata

View file

@ -56,35 +56,23 @@ func containsRequiredHTTPHeaders(method string, headers []string) error {
}
// Client struct
type Client struct {
type ClientFactory struct {
client *http.Client
algs []httpsig.Algorithm
digestAlg httpsig.DigestAlgorithm
getHeaders []string
postHeaders []string
priv *rsa.PrivateKey
pubID string
}
// NewClient function
func NewClient(ctx context.Context, user *user_model.User, pubID string) (c *Client, err error) {
func NewClientFactory() (c *ClientFactory, err error) {
if err = containsRequiredHTTPHeaders(http.MethodGet, setting.Federation.GetHeaders); err != nil {
return nil, err
} else if err = containsRequiredHTTPHeaders(http.MethodPost, setting.Federation.PostHeaders); err != nil {
return nil, err
}
priv, err := GetPrivateKey(ctx, user)
if err != nil {
return nil, err
}
privPem, _ := pem.Decode([]byte(priv))
privParsed, err := x509.ParsePKCS1PrivateKey(privPem.Bytes)
if err != nil {
return nil, err
}
c = &Client{
c = &ClientFactory{
client: &http.Client{
Transport: &http.Transport{
Proxy: proxy.Proxy(),
@ -95,10 +83,47 @@ func NewClient(ctx context.Context, user *user_model.User, pubID string) (c *Cli
digestAlg: httpsig.DigestAlgorithm(setting.Federation.DigestAlgorithm),
getHeaders: setting.Federation.GetHeaders,
postHeaders: setting.Federation.PostHeaders,
}
return c, err
}
type APClientFactory interface {
WithKeys(ctx context.Context, user *user_model.User, pubID string) (APClient, error)
}
// Client struct
type Client struct {
client *http.Client
algs []httpsig.Algorithm
digestAlg httpsig.DigestAlgorithm
getHeaders []string
postHeaders []string
priv *rsa.PrivateKey
pubID string
}
// NewRequest function
func (cf *ClientFactory) WithKeys(ctx context.Context, user *user_model.User, pubID string) (APClient, error) {
priv, err := GetPrivateKey(ctx, user)
if err != nil {
return nil, err
}
privPem, _ := pem.Decode([]byte(priv))
privParsed, err := x509.ParsePKCS1PrivateKey(privPem.Bytes)
if err != nil {
return nil, err
}
c := Client{
client: cf.client,
algs: cf.algs,
digestAlg: cf.digestAlg,
getHeaders: cf.getHeaders,
postHeaders: cf.postHeaders,
priv: privParsed,
pubID: pubID,
}
return c, err
return &c, nil
}
// NewRequest function
@ -185,3 +210,64 @@ func charLimiter(s string, limit int) string {
}
return s
}
type APClient interface {
newRequest(method string, b []byte, to string) (req *http.Request, err error)
Post(b []byte, to string) (resp *http.Response, err error)
Get(to string) (resp *http.Response, err error)
GetBody(uri string) ([]byte, error)
}
// contextKey is a value for use with context.WithValue.
type contextKey struct {
name string
}
// clientFactoryContextKey is a context key. It is used with context.Value() to get the current Food for the context
var (
clientFactoryContextKey = &contextKey{"clientFactory"}
_ APClientFactory = &ClientFactory{}
)
// Context represents an activitypub client factory context
type Context struct {
context.Context
e APClientFactory
}
func NewContext(ctx context.Context, e APClientFactory) *Context {
return &Context{
Context: ctx,
e: e,
}
}
// APClientFactory represents an activitypub client factory
func (ctx *Context) APClientFactory() APClientFactory {
return ctx.e
}
// provides APClientFactory
type GetAPClient interface {
GetClientFactory() APClientFactory
}
// GetClientFactory will get an APClientFactory from this context or returns the default implementation
func GetClientFactory(ctx context.Context) (APClientFactory, error) {
if e := getClientFactory(ctx); e != nil {
return e, nil
}
return NewClientFactory()
}
// getClientFactory will get an APClientFactory from this context or return nil
func getClientFactory(ctx context.Context) APClientFactory {
if clientFactory, ok := ctx.(APClientFactory); ok {
return clientFactory
}
clientFactoryInterface := ctx.Value(clientFactoryContextKey)
if clientFactoryInterface != nil {
return clientFactoryInterface.(GetAPClient).GetClientFactory()
}
return nil
}

View file

@ -64,14 +64,19 @@ Set up a user called "me" for all tests
*/
func TestNewClientReturnsClient(t *testing.T) {
func TestClientCtx(t *testing.T) {
require.NoError(t, unittest.PrepareTestDatabase())
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
pubID := "myGpgId"
c, err := NewClient(db.DefaultContext, user, pubID)
cf, err := NewClientFactory()
log.Debug("ClientFactory: %v\nError: %v", cf, err)
require.NoError(t, err)
c, err := cf.WithKeys(db.DefaultContext, user, pubID)
log.Debug("Client: %v\nError: %v", c, err)
require.NoError(t, err)
_ = NewContext(db.DefaultContext, cf)
}
/* TODO: bring this test to work or delete
@ -109,7 +114,9 @@ func TestActivityPubSignedPost(t *testing.T) {
require.NoError(t, unittest.PrepareTestDatabase())
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
pubID := "https://example.com/pubID"
c, err := NewClient(db.DefaultContext, user, pubID)
cf, err := NewClientFactory()
require.NoError(t, err)
c, err := cf.WithKeys(db.DefaultContext, user, pubID)
require.NoError(t, err)
expected := "BODY"

View file

@ -99,7 +99,11 @@ func ProcessLikeActivity(ctx context.Context, form any, repositoryID int64) (int
func CreateFederationHostFromAP(ctx context.Context, actorID fm.ActorID) (*forgefed.FederationHost, error) {
actionsUser := user.NewActionsUser()
client, err := activitypub.NewClient(ctx, actionsUser, "no idea where to get key material.")
clientFactory, err := activitypub.GetClientFactory(ctx)
if err != nil {
return nil, err
}
client, err := clientFactory.WithKeys(ctx, actionsUser, "no idea where to get key material.")
if err != nil {
return nil, err
}
@ -153,7 +157,11 @@ func GetFederationHostForURI(ctx context.Context, actorURI string) (*forgefed.Fe
func CreateUserFromAP(ctx context.Context, personID fm.PersonID, federationHostID int64) (*user.User, *user.FederatedUser, error) {
// ToDo: Do we get a publicKeyId from server, repo or owner or repo?
actionsUser := user.NewActionsUser()
client, err := activitypub.NewClient(ctx, actionsUser, "no idea where to get key material.")
clientFactory, err := activitypub.GetClientFactory(ctx)
if err != nil {
return nil, nil, err
}
client, err := clientFactory.WithKeys(ctx, actionsUser, "no idea where to get key material.")
if err != nil {
return nil, nil, err
}
@ -262,7 +270,11 @@ func SendLikeActivities(ctx context.Context, doer user.User, repoID int64) error
likeActivityList = append(likeActivityList, likeActivity)
}
apclient, err := activitypub.NewClient(ctx, &doer, doer.APActorID())
apclientFactory, err := activitypub.GetClientFactory(ctx)
if err != nil {
return err
}
apclient, err := apclientFactory.WithKeys(ctx, &doer, doer.APActorID())
if err != nil {
return err
}

View file

@ -98,7 +98,9 @@ func TestActivityPubPersonInbox(t *testing.T) {
user1, err := user_model.GetUserByName(ctx, username1)
require.NoError(t, err)
user1url := fmt.Sprintf("%s/api/v1/activitypub/user-id/1#main-key", srv.URL)
c, err := activitypub.NewClient(db.DefaultContext, user1, user1url)
cf, err := activitypub.GetClientFactory(ctx)
require.NoError(t, err)
c, err := cf.WithKeys(db.DefaultContext, user1, user1url)
require.NoError(t, err)
user2inboxurl := fmt.Sprintf("%s/api/v1/activitypub/user-id/2/inbox", srv.URL)

View file

@ -140,7 +140,9 @@ func TestActivityPubRepositoryInboxValid(t *testing.T) {
}()
actionsUser := user.NewActionsUser()
repositoryID := 2
c, err := activitypub.NewClient(db.DefaultContext, actionsUser, "not used")
cf, err := activitypub.GetClientFactory(db.DefaultContext)
require.NoError(t, err)
c, err := cf.WithKeys(db.DefaultContext, actionsUser, "not used")
require.NoError(t, err)
repoInboxURL := fmt.Sprintf(
"%s/api/v1/activitypub/repository-id/%v/inbox",
@ -232,7 +234,9 @@ func TestActivityPubRepositoryInboxInvalid(t *testing.T) {
}()
actionsUser := user.NewActionsUser()
repositoryID := 2
c, err := activitypub.NewClient(db.DefaultContext, actionsUser, "not used")
cf, err := activitypub.GetClientFactory(db.DefaultContext)
require.NoError(t, err)
c, err := cf.WithKeys(db.DefaultContext, actionsUser, "not used")
require.NoError(t, err)
repoInboxURL := fmt.Sprintf("%s/api/v1/activitypub/repository-id/%v/inbox",
srv.URL, repositoryID)